amildravid4292 committed
Commit c15417b · verified · 1 Parent(s): c376f5e

Update app.py

Files changed (1): app.py +202 -48
app.py CHANGED
@@ -1,3 +1,7 @@
+import torch
+import torchvision
+import torchvision.transforms as transforms
+from torch.utils.data import Dataset, DataLoader
 import gradio as gr
 import sys
 import os
@@ -9,8 +13,9 @@ import warnings
 warnings.filterwarnings("ignore")
 from PIL import Image
 from utils import load_models, save_model_w2w, save_model_for_diffusers
-from sampling import sample_weights
 from editing import get_direction, debias
+from sampling import sample_weights
+from lora_w2w import LoRAw2w
 from huggingface_hub import snapshot_download

 global device
@@ -20,11 +25,13 @@ global vae
 global text_encoder
 global tokenizer
 global noise_scheduler
-
+global network
 device = "cuda:0"
 generator = torch.Generator(device=device)


+
+
 models_path = snapshot_download(repo_id="Snapchat/w2w")

 mean = torch.load(f"{models_path}/files/mean.pt").bfloat16().to(device)
@@ -36,7 +43,7 @@ weight_dimensions = torch.load(f"{models_path}/files/weight_dimensions.pt")
 pinverse = torch.load(f"{models_path}/files/pinverse_1000pc.pt").bfloat16().to(device)

 unet, vae, text_encoder, tokenizer, noise_scheduler = load_models(device)
-global network
+

 def sample_model():
     global unet
@@ -47,6 +54,9 @@ def sample_model():
     network = sample_weights(unet, proj, mean, std, v[:, :1000], device, factor=1.00)


+
+
+
 @torch.no_grad()
 def inference(prompt, negative_prompt, guidance_scale, ddim_steps, seed):
     global device
@@ -94,7 +104,7 @@ def inference(prompt, negative_prompt, guidance_scale, ddim_steps, seed):

     image = Image.fromarray((image * 255).round().astype("uint8"))

-    return [image]
+    return image


@@ -173,16 +183,13 @@ def edit_inference(prompt, negative_prompt, guidance_scale, ddim_steps, seed, st
     network.proj = torch.nn.Parameter(original_weights)
     network.reset()

-    return [image]
+    return image




 def sample_then_run():
-
-
-    sample_model()
-
+    sample_model()
     prompt = "sks person"
     negative_prompt = "low quality, blurry, unfinished, cartoon"
     seed = 5
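
A side note on the two `return [image]` → `return image` changes: they go hand in hand with this commit's switch from gr.Gallery to gr.Image output components further down, since a Gallery output consumes a list of images while an Image output consumes a single one. A tiny standalone illustration (not part of the commit):

import numpy as np

img = np.zeros((64, 64, 3), dtype=np.uint8)  # stand-in for a generated image

def gallery_style():
    return [img]  # a gr.Gallery output expects a list of images

def image_style():
    return img    # a gr.Image output expects a single image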
@@ -192,6 +199,8 @@ def sample_then_run():
     return image


+
+
 #directions
 global young
 global pointy
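
The `large_max`/`large_min` context lines at the top of the next hunk bound each attribute slider by projecting the model database onto the edit direction. A minimal sketch of that computation with dummy shapes (`proj` and `d` here are illustrative stand-ins, not the real tensors):

import torch

proj = torch.randn(100, 1000)  # stand-in: N models in PCA coordinates
d = torch.randn(1, 1000)       # stand-in: one edit direction from get_direction()

# Coefficient of each model along d; the max/min bound the UI slider range.
coeffs = proj @ d[0] / (torch.norm(d) ** 2)
print(coeffs.max().item(), coeffs.min().item())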
@@ -233,6 +242,115 @@ large = debias(large, "Wavy_Hair", df, pinverse, device)
 large_max = torch.max(proj@large[0]/(torch.norm(large))**2).item()
 large_min = torch.min(proj@large[0]/(torch.norm(large))**2).item()

+class CustomImageDataset(Dataset):
+    def __init__(self, images, transform=None):
+        self.images = images
+        self.transform = transform
+
+    def __len__(self):
+        return len(self.images)
+
+    def __getitem__(self, idx):
+        image = self.images[idx]
+        if self.transform:
+            image = self.transform(image)
+        return image
+
+
+def invert(image, mask, pcs=10000, epochs=400, weight_decay=1e-10, lr=1e-1):
+    global unet
+    del unet
+    global network
+    unet, _, _, _, _ = load_models(device)
+
+    proj = torch.zeros(1, pcs).bfloat16().to(device)
+    network = LoRAw2w(proj, mean, std, v[:, :pcs],
+                      unet,
+                      rank=1,
+                      multiplier=1.0,
+                      alpha=27.0,
+                      train_method="xattn-strict"
+                      ).to(device, torch.bfloat16)
+
+    ### load mask (64x64 matches the latent resolution of 512x512 inputs)
+    mask = transforms.Resize((64, 64), interpolation=transforms.InterpolationMode.BILINEAR)(mask)
+    mask = torchvision.transforms.functional.pil_to_tensor(mask).unsqueeze(0).to(device).bfloat16()[:, 0, :, :].unsqueeze(1)
+    ### check if an actual mask was drawn, otherwise mask is just all ones
+    if torch.sum(mask) == 0:
+        mask = torch.ones((1, 1, 64, 64)).to(device).bfloat16()
+
+    ### single-image dataset
+    image_transforms = transforms.Compose([transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR),
+                                           transforms.RandomCrop(512),
+                                           transforms.ToTensor(),
+                                           transforms.Normalize([0.5], [0.5])])
+
+    train_dataset = CustomImageDataset(image, transform=image_transforms)
+    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=1, shuffle=True)
+
+    ### optimizer
+    optim = torch.optim.Adam(network.parameters(), lr=lr, weight_decay=weight_decay)
+
+    ### training loop
+    unet.train()
+    for epoch in tqdm.tqdm(range(epochs)):
+        for batch in train_dataloader:
+            ### prepare inputs
+            batch = batch.to(device).bfloat16()
+            latents = vae.encode(batch).latent_dist.sample()
+            latents = latents * 0.18215  # SD VAE latent scaling factor
+            noise = torch.randn_like(latents)
+            bsz = latents.shape[0]
+
+            timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
+            timesteps = timesteps.long()
+            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+            text_input = tokenizer("sks person", padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
+            text_embeddings = text_encoder(text_input.input_ids.to(device))[0]
+
+            ### loss + sgd step
+            with network:
+                model_pred = unet(noisy_latents, timesteps, text_embeddings).sample
+            loss = torch.nn.functional.mse_loss(mask*model_pred.float(), mask*noise.float(), reduction="mean")
+            optim.zero_grad()
+            loss.backward()
+            optim.step()
+
+    ### return optimized network
+    return network
+
+
+def run_inversion(dict, pcs, epochs, weight_decay, lr):
+    global network
+    init_image = dict["image"].convert("RGB").resize((512, 512))
+    mask = dict["mask"].convert("RGB").resize((512, 512))
+    network = invert([init_image], mask, pcs, epochs, weight_decay, lr)
+
+    # sample an image from the inverted model
+    prompt = "sks person"
+    negative_prompt = "low quality, blurry, unfinished, cartoon"
+    seed = 5
+    cfg = 3.0
+    steps = 50
+    image = inference(prompt, negative_prompt, cfg, steps, seed)
+    torch.save(network.proj, "model.pt")
+    return image, "model.pt"
+

 intro = """
 <div style="display: flex;align-items: center;justify-content: center">
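
The training loop added above minimizes the standard epsilon-prediction diffusion objective, restricted to the drawn region: the mask zeroes out both the UNet prediction and the target noise outside the sketch, so only masked latent pixels contribute gradient to the LoRA parameters. A self-contained sketch of that masked loss on dummy tensors (shapes mirror the 64x64 latent grid; not part of the commit):

import torch
import torch.nn.functional as F

eps = torch.randn(1, 4, 64, 64)       # stand-in for the sampled noise
eps_pred = torch.randn(1, 4, 64, 64)  # stand-in for the UNet prediction
m = torch.zeros(1, 1, 64, 64)
m[..., 16:48, 16:48] = 1.0            # pretend the user drew a square

# Masked MSE: pixels outside the mask contribute zero loss and zero gradient.
loss = F.mse_loss(m * eps_pred, m * eps, reduction="mean")
print(loss.item())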
@@ -249,61 +367,97 @@ intro = """
 </p>
 """

-with gr.Blocks(css="style.css") as demo:
-    gr.HTML(intro)
-    with gr.Row():
-        with gr.Column():
-            gallery1 = gr.Gallery(label="Identity from Sampled Model")
-            sample = gr.Button("Sample New Model")
-            gallery2 = gr.Gallery(label="Identity from Edited Model")
-
-
-    with gr.Row():
-        with gr.Column():
-            prompt = gr.Textbox(label="Prompt",
-                                info="Make sure to include 'sks person'",
-                                placeholder="sks person",
-                                value="sks person")
-            negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="low quality, blurry, unfinished, cartoon", value="low quality, blurry, unfinished, cartoon")
-            with gr.Row():
-                a1 = gr.Slider(label="+Young", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
-                a2 = gr.Slider(label="+Pointy Nose", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
-            with gr.Row():
-                a3 = gr.Slider(label="+Curly Hair", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
-                a4 = gr.Slider(label="+Large", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
-
-
-            with gr.Accordion("Advanced Options", open=False):
-                with gr.Column():
-                    seed = gr.Number(value=5, label="Seed", interactive=True)
-                    cfg = gr.Slider(label="CFG", value=3.0, step=0.1, minimum=0, maximum=10, interactive=True)
-                    steps = gr.Slider(label="Inference Steps", value=50, step=1, minimum=0, maximum=100, interactive=True)
-                    injection_step = gr.Slider(label="Injection Step", value=800, step=1, minimum=0, maximum=1000, interactive=True)
-
-            submit = gr.Button("Submit")
-
-    sample.click(fn=sample_then_run, outputs=gallery1)
-
-    submit.click(fn=edit_inference,
-                 inputs=[prompt, negative_prompt, cfg, steps, seed, injection_step, a1, a2, a3, a4],
-                 outputs=gallery2)
-
-demo.launch(share=True)
+with gr.Blocks(css="style.css") as demo:
+    gr.HTML(intro)
+    with gr.Tab("Sampling Models + Editing"):
+        with gr.Row():
+            with gr.Column():
+                gallery1 = gr.Image(label="Identity from Sampled Model")
+                sample = gr.Button("Sample New Model")
+                gallery2 = gr.Image(label="Identity from Edited Model")
+
+        with gr.Row():
+            with gr.Column():
+                prompt = gr.Textbox(label="Prompt",
+                                    info="Make sure to include 'sks person'",
+                                    placeholder="sks person",
+                                    value="sks person")
+                negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="low quality, blurry, unfinished, cartoon", value="low quality, blurry, unfinished, cartoon")
+                with gr.Row():
+                    a1 = gr.Slider(label="- Young +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
+                    a2 = gr.Slider(label="- Pointy Nose +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
+                with gr.Row():
+                    a3 = gr.Slider(label="- Curly Hair +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
+                    a4 = gr.Slider(label="- placeholder for some fourth attribute +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
+
+                with gr.Accordion("Advanced Options", open=False):
+                    with gr.Column():
+                        seed = gr.Number(value=5, label="Seed", precision=0, interactive=True)
+                        cfg = gr.Slider(label="CFG", value=3.0, step=0.1, minimum=0, maximum=10, interactive=True)
+                        steps = gr.Slider(label="Inference Steps", value=50, step=1, minimum=0, maximum=100, interactive=True)
+                        injection_step = gr.Slider(label="Injection Step", value=800, step=1, minimum=0, maximum=1000, interactive=True)
+
+                submit = gr.Button("Generate")
+
+        sample.click(fn=sample_then_run, outputs=gallery1)
+
+        submit.click(fn=edit_inference,
+                     inputs=[prompt, negative_prompt, cfg, steps, seed, injection_step, a1, a2, a3, a4],
+                     outputs=gallery2)
+
+    with gr.Tab("Inversion"):
+        with gr.Row():
+            with gr.Column():
+                input_image = gr.Image(source='upload', elem_id="image_upload", tool='sketch', type='pil', label="Upload image and draw to define mask",
+                                       height=512, width=512, brush_color='#00FFFF', mask_opacity=0.6)
+
+                lr = gr.Number(value=1e-1, label="Learning Rate", interactive=True)
+                weight_decay = gr.Number(value=1e-10, label="Weight Decay", interactive=True)
+                pcs = gr.Slider(label="# Principal Components", value=10000, step=1, minimum=1, maximum=10000, interactive=True)
+                epochs = gr.Slider(label="Epochs", value=400, step=1, minimum=1, maximum=2000, interactive=True)
+
+                invert_button = gr.Button("Invert")
+
+            with gr.Column():
+                gallery = gr.Image(label="Sample from Inverted Model", height=512, width=512)
+                prompt = gr.Textbox(label="Prompt",
+                                    info="Make sure to include 'sks person'",
+                                    placeholder="sks person",
+                                    value="sks person")
+                negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="low quality, blurry, unfinished, cartoon", value="low quality, blurry, unfinished, cartoon")
+                seed = gr.Number(value=5, label="Seed", precision=0, interactive=True)
+                cfg = gr.Slider(label="CFG", value=3.0, step=0.1, minimum=0, maximum=10, interactive=True)
+                steps = gr.Slider(label="Inference Steps", value=50, step=1, minimum=0, maximum=100, interactive=True)
+                submit = gr.Button("Generate")
+
+                file_output = gr.File(label="Download Model", container=False)
+
+        invert_button.click(fn=run_inversion,
+                            inputs=[input_image, pcs, epochs, weight_decay, lr],
+                            outputs=[gallery, file_output])
+
+        submit.click(fn=inference,
+                     inputs=[prompt, negative_prompt, cfg, steps, seed],
+                     outputs=gallery)
+
+demo.queue().launch(share=True)
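
For completeness, the new inversion entry point can also be exercised without the UI once this app.py's globals are loaded; a minimal sketch, assuming a local face.jpg (the filename, blank mask, and hyperparameters are illustrative, not part of the commit):

import torch
from PIL import Image

init_image = Image.open("face.jpg").convert("RGB").resize((512, 512))  # assumed local file
blank_mask = Image.new("RGB", (512, 512), 0)  # all-zero sketch: invert() falls back to an all-ones mask

network = invert([init_image], blank_mask, pcs=1000, epochs=100, weight_decay=1e-10, lr=1e-1)
torch.save(network.proj, "model.pt")  # same artifact the UI exposes for download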