amildravid4292 committed
Commit 86ffd66 · verified · 1 Parent(s): f112774

Update app.py

Files changed (1)
  1. app.py +92 -148
app.py CHANGED
@@ -1,27 +1,26 @@
  import os
- # os.system("pip uninstall -y gradio")
- # #os.system('pip install gradio==3.43.1')
-
+ os.system("pip uninstall -y gradio")
+ os.system('pip install gradio==3.43.1')
  import torch
  import torchvision
  import torchvision.transforms as transforms
  from torch.utils.data import Dataset, DataLoader
  import gradio as gr
  import sys
- import os
  import tqdm
  sys.path.append(os.path.abspath(os.path.join("", "..")))
- import torch
  import gc
  import warnings
  warnings.filterwarnings("ignore")
  from PIL import Image
- from utils import load_models, save_model_w2w, save_model_for_diffusers
+ import numpy as np
+ from utils import load_models
  from editing import get_direction, debias
  from sampling import sample_weights
  from lora_w2w import LoRAw2w
  from huggingface_hub import snapshot_download
- import numpy as np
+ import spaces
+


  global device
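
The commit's first functional change is the pair of os.system calls above: the Space now force-downgrades to Gradio 3.43.1 at startup, matching the UI rewrite below from gr.ImageEditor (a Gradio 4 component) back to gr.Image(tool='sketch'), which Gradio 4 removed. A hedged sketch of a gentler variant that only reinstalls when the major version differs; the unconditional reinstall is what app.py actually ships:

    import os
    import gradio as gr

    # Hypothetical guard, not in app.py: skip the reinstall when the
    # installed Gradio is already a 3.x release.
    if not gr.__version__.startswith("3."):
        os.system("pip uninstall -y gradio")
        os.system("pip install gradio==3.43.1")
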
@@ -32,11 +31,9 @@ global text_encoder
  global tokenizer
  global noise_scheduler
  global network
- global original_image
  device = "cuda:0"
  generator = torch.Generator(device=device)
- from gradio_imageslider import ImageSlider
- import spaces
+


  models_path = snapshot_download(repo_id="Snapchat/w2w")
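
For context on the lines kept here: a single torch.Generator is created on the device once, and the inference paths below re-seed it per request (generator.manual_seed(seed) in a later hunk), which keeps sampled latents reproducible for a given seed. A minimal sketch; the latent shape is an assumption, since the torch.randn call is truncated in this diff:

    import torch

    g = torch.Generator(device="cpu")        # app.py uses device="cuda:0"
    g.manual_seed(5)
    # Assumed SD-style latent shape (batch, channels, height, width).
    latents = torch.randn(1, 4, 64, 64, generator=g)
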
@@ -125,12 +122,9 @@ def edit_inference(prompt, negative_prompt, guidance_scale, ddim_steps, seed, st
  global pointy
  global wavy
  global large
- global original_image
-

  original_weights = network.proj.clone()

-
  #pad to same number of PCs
  pcs_original = original_weights.shape[1]
  pcs_edits = young.shape[1]
@@ -141,7 +135,7 @@
  large_pad = torch.cat((large, padding), 1)


- edited_weights = original_weights+a1*1e6*young_pad+a2*1e6*pointy_pad+a3*1e6*wavy_pad+a4*8e5*large_pad
+ edited_weights = original_weights+a1*1e6*young_pad+a2*1e6*pointy_pad+a3*1e6*wavy_pad+a4*2e6*large_pad

  generator = generator.manual_seed(seed)
  latents = torch.randn(
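
The single changed line above is the heart of the demo: identity-encoding weights are coefficients in a PCA basis, attribute edits are directions in that basis, and each slider value scales its direction before addition; this commit only retunes the last coefficient (the "Thick Eyebrows" direction) from 8e5 to 2e6. A minimal sketch with illustrative shapes:

    import torch

    proj = torch.randn(1, 10000)    # stand-in for network.proj (PCA coefficients)
    young = torch.randn(1, 1000)    # stand-in edit direction with fewer PCs
    # Pad the direction to the model's number of principal components,
    # exactly as the #pad block above does.
    padding = torch.zeros(1, proj.shape[1] - young.shape[1])
    young_pad = torch.cat((young, padding), 1)
    a1 = 0.5                        # slider value in [-1, 1]
    edited = proj + a1 * 1e6 * young_pad   # same linear form as edited_weights
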
@@ -197,22 +191,19 @@ def edit_inference(prompt, negative_prompt, guidance_scale, ddim_steps, seed, st
  #reset weights back to original
  network.proj = torch.nn.Parameter(original_weights)
  network.reset()
-
- return (original_image, image)
+
+ return image

  def sample_then_run():
- global original_image
  sample_model()
  prompt = "sks person"
  negative_prompt = "low quality, blurry, unfinished, nudity, weapon"
  seed = 5
  cfg = 3.0
  steps = 50
- original_image = inference( prompt, negative_prompt, cfg, steps, seed)
+ image = inference( prompt, negative_prompt, cfg, steps, seed)
  torch.save(network.proj, "model.pt" )
-
-
- return (original_image, original_image), "model.pt"
+ return image, "model.pt"


  global young
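
sample_then_run now returns a plain (image, path) pair instead of the old image-slider tuples. In Gradio Blocks, a handler that returns N values fills the N components in its outputs list, which is how this pair later lands in outputs=[gallery, file_output]. A self-contained sketch of that contract, with toy components standing in for the real ones:

    import gradio as gr

    def handler():
        # One return value per output component, in order.
        return "a generated value", "model.pt"

    with gr.Blocks() as demo:
        txt = gr.Textbox()   # stands in for the gr.Image result
        f = gr.File()        # receives the saved model path
        gr.Button("Sample").click(fn=handler, outputs=[txt, f])
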
@@ -275,14 +266,10 @@ class CustomImageDataset(Dataset):
  image = self.transform(image)
  return image

- def invert(dict, pcs=10000, epochs=400, weight_decay = 1e-10, lr=1e-1):
+ def invert(image, mask, pcs=10000, epochs=400, weight_decay = 1e-10, lr=1e-1):
  global unet
  del unet
  global network
-
- image = dict["background"].convert("RGB").resize((512, 512))
- mask = dict["layers"][0].convert("RGB").resize((512, 512))
-
  unet, _, _, _, _ = load_models(device)

  proj = torch.zeros(1,pcs).bfloat16().to(device)
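
The signature change above (and the matching unpacking added to run_inversion later in this diff) tracks the two editor payload shapes: Gradio 4's gr.ImageEditor passes a dict with "background" and "layers" keys, while Gradio 3's gr.Image(tool='sketch') passes "image" and "mask". A hedged helper that accepts either shape, purely for illustration:

    def unpack_sketch_payload(d):
        # Gradio 4 gr.ImageEditor payload (the removed code path above).
        if "background" in d:
            return d["background"], d["layers"][0]
        # Gradio 3 gr.Image(tool='sketch') payload (the added code path).
        return d["image"], d["mask"]
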
@@ -294,18 +281,13 @@ def invert(dict, pcs=10000, epochs=400, weight_decay = 1e-10, lr=1e-1):
  train_method="xattn-strict"
  ).to(device, torch.bfloat16)

-
-
-
-
  ### load mask
  mask = transforms.Resize((64,64), interpolation=transforms.InterpolationMode.BILINEAR)(mask)
  mask = torchvision.transforms.functional.pil_to_tensor(mask).unsqueeze(0).to(device).bfloat16()[:,0,:,:].unsqueeze(1)
  ### check if an actual mask was drawn, otherwise mask is just all ones
  if torch.sum(mask) == 0:
  mask = torch.ones((1,1,64,64)).to(device).bfloat16()
-
-
+
  ### single image dataset
  image_transforms = transforms.Compose([transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR),
  transforms.RandomCrop(512),
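
The 64x64 mask size is tied to Stable Diffusion's latent resolution: the VAE downsamples 512x512 pixels by a factor of 8, so the reconstruction loss is masked in latent space. A small sketch of the drawn-versus-empty logic this hunk keeps:

    import torch

    latent_hw = 512 // 8                            # 64 for SD at 512x512
    mask = torch.zeros(1, 1, latent_hw, latent_hw)  # stand-in for the drawn mask
    if torch.sum(mask) == 0:
        # Nothing was drawn: train on the whole image instead.
        mask = torch.ones_like(mask)
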
@@ -313,11 +295,9 @@ def invert(dict, pcs=10000, epochs=400, weight_decay = 1e-10, lr=1e-1):
  transforms.Normalize([0.5], [0.5])])


- train_dataset = CustomImageDataset([image], transform=image_transforms)
+ train_dataset = CustomImageDataset(image, transform=image_transforms)
  train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=1, shuffle=True)

-
-
  ### optimizer
  optim = torch.optim.Adam(network.parameters(), lr=lr, weight_decay=weight_decay)

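
CustomImageDataset, defined earlier in app.py and only partially visible at the top of this diff, simply wraps a list holding the one uploaded image so the usual DataLoader epoch loop can be reused. A minimal sketch of the pattern under a hypothetical name:

    from torch.utils.data import Dataset

    class SingleImageDataset(Dataset):
        # Hypothetical stand-in for app.py's CustomImageDataset.
        def __init__(self, images, transform=None):
            self.images = images          # e.g. [pil_image]
            self.transform = transform

        def __len__(self):
            return len(self.images)

        def __getitem__(self, idx):
            image = self.images[idx]
            return self.transform(image) if self.transform else image
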
@@ -347,40 +327,34 @@ def invert(dict, pcs=10000, epochs=400, weight_decay = 1e-10, lr=1e-1):
  optim.step()

  ### return optimized network
-
  return network



-
  def run_inversion(dict, pcs, epochs, weight_decay,lr):
  global network
- global original_image
- # init_image = dict["image"].convert("RGB").resize((512, 512))
- # mask = dict["ma print(dict)
- network = invert( dict, pcs, epochs, weight_decay,lr)
+ init_image = dict["image"].convert("RGB").resize((512, 512))
+ mask = dict["mask"].convert("RGB").resize((512, 512))
+ network = invert([init_image], mask, pcs, epochs, weight_decay,lr)


  #sample an image
  prompt = "sks person"
- negative_prompt = "low quality, blurry, unfinished, nudity, weapon"
+ negative_prompt = "low quality, blurry, unfinished, nudity"
  seed = 5
  cfg = 3.0
  steps = 50
- original_image = inference( prompt, negative_prompt, cfg, steps, seed)
+ image = inference( prompt, negative_prompt, cfg, steps, seed)
  torch.save(network.proj, "model.pt" )
- return (original_image, original_image), "model.pt"
+ return image, "model.pt"
+

-
-

  def file_upload(file):
  global unet
  del unet
  global network
  global device
- global original_image
-



@@ -393,39 +367,38 @@ def file_upload(file):

  unet, _, _, _, _ = load_models(device)

-
- network = LoRAw2w( proj, mean, std, v[:, :10000],
+
+ network = LoRAw2w( proj, mean, std, v[:, :pcs],
  unet,
  rank=1,
  multiplier=1.0,
  alpha=27.0,
  train_method="xattn-strict"
  ).to(device, torch.bfloat16)
-
+

  prompt = "sks person"
- negative_prompt = "low quality, blurry, unfinished, nudity, weapon"
+ negative_prompt = "low quality, blurry, unfinished, nudity"
  seed = 5
  cfg = 3.0
  steps = 50
- original_image = inference( prompt, negative_prompt, cfg, steps, seed)
- return (original_image, original_image)
-
-
+ image = inference( prompt, negative_prompt, cfg, steps, seed)
+ return image


+




  intro = """
  <div style="display: flex;align-items: center;justify-content: center">
- <h2 style="display: inline-block;margin-left: 10px;margin-top: 6px;font-weight: 500">Interpreting the Weight Space of Customized Diffusion Models (aka <b> <em>weights2weights</em></b>)</h2>
+ <h1 style="margin-left: 12px;text-align: center;margin-bottom: 7px;display: inline-block">weights2weights</h1>
+ <h3 style="display: inline-block;margin-left: 10px;margin-top: 6px;font-weight: 500">Interpreting the Weight Space of Customized Diffusion Models</h3>
  </div>
  <p style="font-size: 0.95rem;margin: 0rem;line-height: 1.2em;margin-top:1em;display: inline-block">
- <a href="https://snap-research.github.io/weights2weights/" target="_blank">Project Page</a> | <a href="https://arxiv.org/abs/2406.09413" target="_blank">Paper</a>
+ <a href="https://snap-research.github.io/weights2weights/" target="_blank">project page</a> | <a href="https://arxiv.org/abs/2406.09413" target="_blank">paper</a>
  |
- <a href="https://github.com/snap-research/weights2weights" target="_blank">Code</a> |
  <a href="https://huggingface.co/spaces/Snapchat/w2w-demo?duplicate=true" target="_blank" style="
  display: inline-block;
  ">
@@ -437,115 +410,86 @@ intro = """

  with gr.Blocks(css="style.css") as demo:
  gr.HTML(intro)
- with gr.Tab("Model Editing"):
- gr.Markdown("""
- Click the `Sample New Model` to sample a new identity-encoding model or upload a model to get started ✨
- """)
- with gr.Column():
- with gr.Row():
- with gr.Column():
+
+ gr.Markdown("""<div style="text-align: justify;"> Click below to sample an identity-encoding model, or upload an image below and click \"invert\". You can also optionally draw over the face to define a mask. To use a model previously downloaded from this demo, see \"Uploading a model\" in the Advanced options""")
+ with gr.Column():
+ with gr.Row():
+ with gr.Column():
+ input_image = gr.Image(source='upload', elem_id="image_upload", tool='sketch', type='pil', label="Upload image and draw to define mask",
+ height=512, width=512, brush_color='#00FFFF', mask_opacity=0.6)
+
+ with gr.Row():
  sample = gr.Button("🎲 Sample New Model")
- file_output1 = gr.File(label="Download Sampled Model", container=True, interactive=False)
- file_input = gr.File(label="Upload Model", container=True)
-
-
- with gr.Column():
- image_slider1 = ImageSlider(position=0.5, type="pil", height=512, width=512, label= "Reference Identity | Generated Samples by User")
-
- prompt1 = gr.Textbox(label="Prompt",
- info="Make sure to include 'sks person'" ,
- placeholder="sks person",
- value="sks person")
- seed1 = gr.Number(value=5, label="Seed", precision=0, interactive=True)
-
+ invert_button = gr.Button("⬆️ Invert")
+ with gr.Column():
+ gallery = gr.Image(label="Image",height=512, width=512, interactive=False)

+ prompt = gr.Textbox(label="Prompt",
+ info="Make sure to include 'sks person'" ,
+ placeholder="sks person",
+ value="sks person")
+
+ seed = gr.Number(value=5, label="Seed", precision=0, interactive=True)

-
- with gr.Row():
- a1_1 = gr.Slider(label="- Young +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
- a2_1 = gr.Slider(label="- Pointy Nose +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
- with gr.Row():
- a3_1 = gr.Slider(label="- Curly Hair +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
- a4_1 = gr.Slider(label="- Thick Eyebrows +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
-
-
-
- with gr.Accordion("Advanced Options", open=False):
- cfg1= gr.Slider(label="CFG", value=3.0, step=0.1, minimum=0, maximum=10, interactive=True)
- steps1 = gr.Slider(label="Inference Steps", value=50, step=1, minimum=0, maximum=100, interactive=True)
- negative_prompt1 = gr.Textbox(label="Negative Prompt", placeholder="low quality, blurry, unfinished, nudity, weapon", value="low quality, blurry, unfinished, nudity, weapon")
- injection_step1 = gr.Slider(label="Injection Step", value=800, step=1, minimum=0, maximum=1000, interactive=True)
-
-
-
-
- submit1 = gr.Button("Generate")
-
- with gr.Tab("Inversion"):
- gr.Markdown("""
- Upload an image and optionally define a mask by drawing over the face. Then click `invert` to get started ✨
- """)
- with gr.Row():
+ # Editing
  with gr.Column():
- input_image = gr.ImageEditor(elem_id="image_upload", type='pil', label="Upload image and draw to define mask", height=512, width=512, brush=gr.Brush(), layers=False)
- lr = gr.Number(value=1e-1, label="Learning Rate", interactive=True)
- pcs = gr.Slider(label="# Principal Components", value=10000, step=1, minimum=1, maximum=10000, interactive=True)
- epochs = gr.Slider(label="Epochs", value=400, step=1, minimum=1, maximum=2000, interactive=True)
- weight_decay = gr.Number(value=1e-10, label="Weight Decay", interactive=True)
- invert_button = gr.Button("Invert")
- file_output2 = gr.File(label="Download Inverted Model", container=True, interactive=False)
+ with gr.Row():
+ a1 = gr.Slider(label="- Young +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
+ a2 = gr.Slider(label="- Pointy Nose +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
+ with gr.Row():
+ a3 = gr.Slider(label="- Curly Hair +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
+ a4 = gr.Slider(label="- Thick Eyebrows +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)

- with gr.Column():
- image_slider2 = ImageSlider(position=0.5, type="pil", height=512, width=512, label= "Reference Identity | Generated Samples by User")
+
+ with gr.Accordion("Advanced Options", open=False):
+ with gr.Tab("Inversion"):
+ with gr.Row():
+ lr = gr.Number(value=1e-1, label="Learning Rate", interactive=True)
+ pcs = gr.Slider(label="# Principal Components", value=10000, step=1, minimum=1, maximum=10000, interactive=True)
+ with gr.Row():
+ epochs = gr.Slider(label="Epochs", value=800, step=1, minimum=1, maximum=2000, interactive=True)
+ weight_decay = gr.Number(value=1e-10, label="Weight Decay", interactive=True)
+ with gr.Tab("Sampling"):
+ with gr.Row():
+ cfg= gr.Slider(label="CFG", value=3.0, step=0.1, minimum=0, maximum=10, interactive=True)
+ steps = gr.Slider(label="Inference Steps", value=50, step=1, minimum=0, maximum=100, interactive=True)
+ with gr.Row():
+ negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="low quality, blurry, unfinished, nudity, weapon", value="low quality, blurry, unfinished, nudity, weapon")
+ injection_step = gr.Slider(label="Injection Step", value=800, step=1, minimum=0, maximum=1000, interactive=True)

- prompt2 = gr.Textbox(label="Prompt",
- info="Make sure to include 'sks person'" ,
- placeholder="sks person",
- value="sks person")
- seed2 = gr.Number(value=5, label="Seed", precision=0, interactive=True)
+ with gr.Tab("Uploading a model"):
+ gr.Markdown("""<div style="text-align: justify;">Upload a model previously downloaded from this demo.""")

-
-
+ file_input = gr.File(label="Upload Model", container=True)

- with gr.Row():
- a1_2 = gr.Slider(label="- Young +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
- a2_2 = gr.Slider(label="- Pointy Nose +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
- with gr.Row():
- a3_2 = gr.Slider(label="- Curly Hair +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
- a4_2 = gr.Slider(label="- Thick Eyebrows +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
+ submit = gr.Button("Generate")
+

-
-
- with gr.Accordion("Advanced Options", open=False):
- cfg2= gr.Slider(label="CFG", value=3.0, step=0.1, minimum=0, maximum=10, interactive=True)
- steps2 = gr.Slider(label="Inference Steps", value=50, step=1, minimum=0, maximum=100, interactive=True)
- negative_prompt2 = gr.Textbox(label="Negative Prompt", placeholder="low quality, blurry, unfinished, nudity, weapon", value="low quality, blurry, unfinished, nudity, weapon")
- injection_step2 = gr.Slider(label="Injection Step", value=800, step=1, minimum=0, maximum=1000, interactive=True)
-
-
-

- submit2 = gr.Button("Generate")
+ gr.Markdown("""<div style="text-align: justify;"> After sampling a new model or inverting, you can download the model below.""")

+ with gr.Row():
+ file_output = gr.File(label="Download Sampled Model", container=True, interactive=False)


-
-
-
+
+
+
+ invert_button.click(fn=run_inversion,
+ inputs=[input_image, pcs, epochs, weight_decay,lr],
+ outputs = [gallery, file_output])
+ sample.click(fn=sample_then_run, outputs=[gallery, file_output])
+ submit.click(
+ fn=edit_inference, inputs=[prompt, negative_prompt, cfg, steps, seed, injection_step, a1, a2, a3, a4], outputs=[gallery]
+ )
+ file_input.change(fn=file_upload, inputs=file_input, outputs = input_image)



-
- sample.click(fn=sample_then_run, outputs=[image_slider1, file_output1])
- submit1.click(fn=edit_inference, inputs=[ prompt1, negative_prompt1, cfg1, steps1, seed1, injection_step1, a1_1, a2_1, a3_1, a4_1], outputs=image_slider1)
- file_input.change(fn=file_upload, inputs=file_input, outputs = image_slider1)


- invert_button.click(fn=run_inversion, inputs=[input_image, pcs, epochs, weight_decay,lr], outputs = [image_slider2, file_output2])
- submit2.click(fn=edit_inference, inputs=[ prompt2, negative_prompt2, cfg2, steps2, seed2, injection_step2, a1_2, a2_2, a3_2, a4_2], outputs=image_slider2)


-



 
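The wiring block at the end of the new file is the whole control flow of the rebuilt single-page UI: inversion and sampling both fill [gallery, file_output], the Generate button re-renders through edit_inference, and an uploaded model feeds back into input_image. A self-contained sketch of the same Blocks wiring pattern with toy components (app.py's own launch call sits outside this diff):

    import gradio as gr

    def square(x):
        return x * x

    with gr.Blocks() as mini:
        n = gr.Number(value=3, label="n")
        out = gr.Number(label="n squared")
        gr.Button("Generate").click(fn=square, inputs=[n], outputs=[out])

    if __name__ == "__main__":
        mini.queue().launch()  # queue() serializes requests, typical for GPU demos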