LiruiZhao commited on
Commit
041e06c
1 Parent(s): 13a253a

[Major] Add list generation

Browse files
Files changed (1) hide show
  1. app.py +155 -22
app.py CHANGED
@@ -282,22 +282,138 @@ def generate(
282
 
283
  return [int(seed), text_cfg_scale, image_cfg_scale, edited_image, mix_image, edited_mask_copy, mask_video_path, image_video_path, input_image_copy, mix_result_with_red_mask]
284
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
285
  def reset():
286
  return [100, "Randomize Seed", 1372, "Fix CFG", 7.5, 1.5, None, None, None, None, None, None, None, "Close Image Video", 10]
287
 
 
288
  def get_example():
289
  return [
290
- ["example_images/dufu.png", "black and white suit", 100, "Fix Seed", 1372, "Fix CFG", 7.5, 1.5],
291
- ["example_images/girl.jpeg", "reflective sunglasses", 100, "Fix Seed", 1372, "Fix CFG", 7.5, 1.5],
292
- ["example_images/road_sign.png", "stop sign", 100, "Fix Seed", 1372, "Fix CFG", 7.5, 1.5],
293
- ["example_images/dufu.png", "blue medical mask", 100, "Fix Seed", 1372, "Fix CFG", 7.5, 1.5],
294
- ["example_images/people_standing.png", "dark green pleated skirt", 100, "Fix Seed", 1372, "Fix CFG", 7.5, 1.5],
295
- ["example_images/girl.jpeg", "shiny golden crown", 100, "Fix Seed", 1372, "Fix CFG", 7.5, 1.5],
296
- ["example_images/dufu.png", "sunglasses", 100, "Fix Seed", 1372, "Fix CFG", 7.5, 1.5],
297
- ["example_images/girl.jpeg", "diamond necklace", 100, "Fix Seed", 1372, "Fix CFG", 7.5, 1.5],
298
- ["example_images/iron_man.jpg", "sunglasses", 100, "Fix Seed", 1372, "Fix CFG", 7.5, 1.5],
299
- ["example_images/girl.jpeg", "the queen's crown", 100, "Fix Seed", 1372, "Fix CFG", 7.5, 1.5],
300
- ["example_images/girl.jpeg", "gorgeous yellow gown", 100, "Fix Seed", 1372, "Fix CFG", 7.5, 1.5],
 
 
301
  ]
302
 
303
  with gr.Blocks(css="footer {visibility: hidden}") as demo:
@@ -325,7 +441,14 @@ with gr.Blocks(css="footer {visibility: hidden}") as demo:
325
  with gr.Row():
326
  input_image = gr.Image(label="Input Image", type="pil", interactive=True)
327
  with gr.Row():
328
- instruction = gr.Textbox(lines=1, label="Object description", interactive=True)
 
 
 
 
 
 
 
329
  with gr.Row():
330
  steps = gr.Number(value=100, precision=0, label="Steps", interactive=True)
331
  randomize_seed = gr.Radio(
@@ -347,17 +470,13 @@ with gr.Blocks(css="footer {visibility: hidden}") as demo:
347
  )
348
  text_cfg_scale = gr.Number(value=7.5, label=f"Text CFG", interactive=True)
349
  image_cfg_scale = gr.Number(value=1.5, label=f"Image CFG", interactive=True)
350
- with gr.Row():
351
- reset_button = gr.Button("Reset")
352
- generate_button = gr.Button("Generate")
353
  with gr.Column(scale=1, min_width=100):
354
  with gr.Column():
355
  mix_image = gr.Image(label=f"Mix Image", type="pil", interactive=False)
356
  with gr.Column():
357
  edited_mask = gr.Image(label=f"Output Mask", type="pil", interactive=False)
358
 
359
-
360
- with gr.Accordion('More outputs', open=False):
361
  with gr.Row():
362
  weather_close_video = gr.Radio(
363
  ["Show Image Video", "Close Image Video"],
@@ -374,15 +493,11 @@ with gr.Blocks(css="footer {visibility: hidden}") as demo:
374
  original_image = gr.Image(label=f"Original Image", type="pil", interactive=False)
375
  edited_image = gr.Image(label=f"Output Image", type="pil", interactive=False)
376
  mix_result_with_red_mask = gr.Image(label=f"Mix Image With Red Mask", type="pil", interactive=False)
377
-
378
 
379
  with gr.Row():
380
  gr.Examples(
381
  examples=get_example(),
382
- fn=generate,
383
- inputs=[input_image, instruction, steps, randomize_seed, seed, randomize_cfg, text_cfg_scale, image_cfg_scale],
384
- outputs=[seed, text_cfg_scale, image_cfg_scale, edited_image, mix_image, edited_mask, mask_video, image_video, original_image, mix_result_with_red_mask],
385
- cache_examples=False,
386
  )
387
 
388
  generate_button.click(
@@ -401,6 +516,24 @@ with gr.Blocks(css="footer {visibility: hidden}") as demo:
401
  ],
402
  outputs=[seed, text_cfg_scale, image_cfg_scale, edited_image, mix_image, edited_mask, mask_video, image_video, original_image, mix_result_with_red_mask],
403
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
404
  reset_button.click(
405
  fn=reset,
406
  inputs=[],
 
282
 
283
  return [int(seed), text_cfg_scale, image_cfg_scale, edited_image, mix_image, edited_mask_copy, mask_video_path, image_video_path, input_image_copy, mix_result_with_red_mask]
284
 
285
+ @spaces.GPU(duration=30)
286
+ def generate_list(
287
+ input_image: Image.Image,
288
+ generate_list: str,
289
+ steps: int,
290
+ randomize_seed: bool,
291
+ seed: int,
292
+ randomize_cfg: bool,
293
+ text_cfg_scale: float,
294
+ image_cfg_scale: float,
295
+ weather_close_video: bool,
296
+ decode_image_batch: int
297
+ ):
298
+ generate_list = generate_list.split('\n')
299
+ # Remove the empty element
300
+ generate_list = [element for element in generate_list if element]
301
+
302
+ seed = random.randint(0, 100000) if randomize_seed else seed
303
+ text_cfg_scale = round(random.uniform(6.0, 9.0), ndigits=2) if randomize_cfg else text_cfg_scale
304
+ image_cfg_scale = round(random.uniform(1.2, 1.8), ndigits=2) if randomize_cfg else image_cfg_scale
305
+
306
+ width, height = input_image.size
307
+ factor = args.resolution / max(width, height)
308
+ factor = math.ceil(min(width, height) * factor / 64) * 64 / min(width, height)
309
+ width = int((width * factor) // 64) * 64
310
+ height = int((height * factor) // 64) * 64
311
+ input_image = ImageOps.fit(input_image, (width, height), method=Image.Resampling.LANCZOS)
312
+
313
+ if len(generate_list) == 0:
314
+ return [input_image, seed]
315
+
316
+ model.cuda()
317
+ image_video = [np.array(input_image).astype(np.uint8)]
318
+ generate_index = 0
319
+ input_image_copy = input_image.convert("RGB")
320
+ while generate_index < len(generate_list):
321
+ print(f'generate_index: {str(generate_index)}')
322
+ instruction = generate_list[generate_index]
323
+ with torch.no_grad(), autocast("cuda"), model.ema_scope():
324
+ cond = {}
325
+ input_image_torch = 2 * torch.tensor(np.array(input_image_copy.copy())).float() / 255 - 1
326
+ input_image_torch = rearrange(input_image_torch, "h w c -> 1 c h w").to(model.device)
327
+ cond["c_crossattn"] = [model.get_learned_conditioning([instruction]).to(model.device)]
328
+ cond["c_concat"] = [model.encode_first_stage(input_image_torch).mode().to(model.device)]
329
+
330
+ uncond = {}
331
+ uncond["c_crossattn"] = [null_token.to(model.device)]
332
+ uncond["c_concat"] = [torch.zeros_like(cond["c_concat"][0])]
333
+
334
+ sigmas = model_wrap.get_sigmas(steps).to(model.device)
335
+
336
+ extra_args = {
337
+ "cond": cond,
338
+ "uncond": uncond,
339
+ "text_cfg_scale": text_cfg_scale,
340
+ "image_cfg_scale": image_cfg_scale,
341
+ }
342
+ torch.manual_seed(seed)
343
+ z_0 = torch.randn_like(cond["c_concat"][0]).to(model.device) * sigmas[0]
344
+ z_1 = torch.randn_like(cond["c_concat"][0]).to(model.device) * sigmas[0]
345
+
346
+ z_0, z_1, _, _ = sample_euler_ancestral(model_wrap_cfg, z_0, z_1, sigmas, height, width, extra_args=extra_args)
347
+
348
+ x_0 = model.decode_first_stage(z_0)
349
+
350
+ x_1 = nn.functional.interpolate(z_1, size=(height, width), mode="bilinear", align_corners=False)
351
+ x_1 = torch.where(x_1 > 0, 1, -1) # Thresholding step
352
+
353
+ if torch.sum(x_1).item()/x_1.numel() < -0.99:
354
+ seed += 1
355
+ continue
356
+ else:
357
+ generate_index += 1
358
+
359
+ x_0 = torch.clamp((x_0 + 1.0) / 2.0, min=0.0, max=1.0)
360
+ x_1 = torch.clamp((x_1 + 1.0) / 2.0, min=0.0, max=1.0)
361
+ x_0 = 255.0 * rearrange(x_0, "1 c h w -> h w c")
362
+ x_1 = 255.0 * rearrange(x_1, "1 c h w -> h w c")
363
+ x_1 = torch.cat([x_1, x_1, x_1], dim=-1)
364
+ edited_image = Image.fromarray(x_0.type(torch.uint8).cpu().numpy())
365
+ edited_mask = Image.fromarray(x_1.type(torch.uint8).cpu().numpy())
366
+
367
+ # 对edited_mask做膨胀
368
+ edited_mask_copy = edited_mask.copy()
369
+ kernel = np.ones((3, 3), np.uint8)
370
+ edited_mask = cv2.dilate(np.array(edited_mask), kernel, iterations=3)
371
+ edited_mask = Image.fromarray(edited_mask)
372
+
373
+ m_img = edited_mask.filter(ImageFilter.GaussianBlur(radius=3))
374
+ m_img = np.asarray(m_img).astype('float') / 255.0
375
+ img_np = np.asarray(input_image_copy).astype('float') / 255.0
376
+ ours_np = np.asarray(edited_image).astype('float') / 255.0
377
+
378
+ mix_image_np = m_img * ours_np + (1 - m_img) * img_np
379
+
380
+ image_video.append((mix_image_np * 255).astype(np.uint8))
381
+ mix_image = Image.fromarray((mix_image_np * 255).astype(np.uint8)).convert('RGB')
382
+ input_image_copy = mix_image
383
+
384
+ mix_result_with_red_mask = None
385
+ mask_video_path = None
386
+ edited_mask_copy = None
387
+
388
+ image_video_path = "image.mp4"
389
+ fps = 2
390
+ with imageio.get_writer(image_video_path, fps=fps) as video:
391
+ for image in image_video:
392
+ video.append_data(image)
393
+
394
+
395
+ return [int(seed), text_cfg_scale, image_cfg_scale, edited_image, mix_image, edited_mask_copy, mask_video_path, image_video_path, input_image, mix_result_with_red_mask]
396
+
397
+
398
  def reset():
399
  return [100, "Randomize Seed", 1372, "Fix CFG", 7.5, 1.5, None, None, None, None, None, None, None, "Close Image Video", 10]
400
 
401
+
402
  def get_example():
403
  return [
404
+ ["example_images/dufu.png", "", "black and white suit\nsunglasses\nblue medical mask\nyellow schoolbag\nred bow tie\nbrown high-top hat", 100, "Fix Seed", 1372, "Fix CFG", 7.5, 1.5],
405
+ ["example_images/girl.jpeg", "reflective sunglasses", "", 100, "Fix Seed", 1372, "Fix CFG", 7.5, 1.5],
406
+ ["example_images/dufu.png", "black and white suit", "", 100, "Fix Seed", 1372, "Fix CFG", 7.5, 1.5],
407
+ ["example_images/girl.jpeg", "reflective sunglasses", "", 100, "Fix Seed", 1372, "Fix CFG", 7.5, 1.5],
408
+ ["example_images/road_sign.png", "stop sign", "", 100, "Fix Seed", 1372, "Fix CFG", 7.5, 1.5],
409
+ ["example_images/dufu.png", "blue medical mask", "", 100, "Fix Seed", 1372, "Fix CFG", 7.5, 1.5],
410
+ ["example_images/people_standing.png", "dark green pleated skirt", "", 100, "Fix Seed", 1372, "Fix CFG", 7.5, 1.5],
411
+ ["example_images/girl.jpeg", "shiny golden crown", "", 100, "Fix Seed", 1372, "Fix CFG", 7.5, 1.5],
412
+ ["example_images/dufu.png", "sunglasses", "", 100, "Fix Seed", 1372, "Fix CFG", 7.5, 1.5],
413
+ ["example_images/girl.jpeg", "diamond necklace", "", 100, "Fix Seed", 1372, "Fix CFG", 7.5, 1.5],
414
+ ["example_images/iron_man.jpg", "sunglasses", "", 100, "Fix Seed", 1372, "Fix CFG", 7.5, 1.5],
415
+ ["example_images/girl.jpeg", "the queen's crown", "", 100, "Fix Seed", 1372, "Fix CFG", 7.5, 1.5],
416
+ ["example_images/girl.jpeg", "gorgeous yellow gown", "", 100, "Fix Seed", 1372, "Fix CFG", 7.5, 1.5],
417
  ]
418
 
419
  with gr.Blocks(css="footer {visibility: hidden}") as demo:
 
441
  with gr.Row():
442
  input_image = gr.Image(label="Input Image", type="pil", interactive=True)
443
  with gr.Row():
444
+ instruction = gr.Textbox(lines=1, label="Single object description", interactive=True)
445
+ with gr.Row():
446
+ reset_button = gr.Button("Reset")
447
+ generate_button = gr.Button("Generate")
448
+ with gr.Row():
449
+ list_input = gr.Textbox(label="Input List", placeholder="Enter one item per line", lines=10)
450
+ with gr.Row():
451
+ list_generate_button = gr.Button("List Generate")
452
  with gr.Row():
453
  steps = gr.Number(value=100, precision=0, label="Steps", interactive=True)
454
  randomize_seed = gr.Radio(
 
470
  )
471
  text_cfg_scale = gr.Number(value=7.5, label=f"Text CFG", interactive=True)
472
  image_cfg_scale = gr.Number(value=1.5, label=f"Image CFG", interactive=True)
 
 
 
473
  with gr.Column(scale=1, min_width=100):
474
  with gr.Column():
475
  mix_image = gr.Image(label=f"Mix Image", type="pil", interactive=False)
476
  with gr.Column():
477
  edited_mask = gr.Image(label=f"Output Mask", type="pil", interactive=False)
478
 
479
+ with gr.Accordion('Click to see more (includes generation process per object for list generation and per step for single generation)', open=False):
 
480
  with gr.Row():
481
  weather_close_video = gr.Radio(
482
  ["Show Image Video", "Close Image Video"],
 
493
  original_image = gr.Image(label=f"Original Image", type="pil", interactive=False)
494
  edited_image = gr.Image(label=f"Output Image", type="pil", interactive=False)
495
  mix_result_with_red_mask = gr.Image(label=f"Mix Image With Red Mask", type="pil", interactive=False)
 
496
 
497
  with gr.Row():
498
  gr.Examples(
499
  examples=get_example(),
500
+ inputs=[input_image, instruction, list_input, steps, randomize_seed, seed, randomize_cfg, text_cfg_scale, image_cfg_scale],
 
 
 
501
  )
502
 
503
  generate_button.click(
 
516
  ],
517
  outputs=[seed, text_cfg_scale, image_cfg_scale, edited_image, mix_image, edited_mask, mask_video, image_video, original_image, mix_result_with_red_mask],
518
  )
519
+
520
+ list_generate_button.click(
521
+ fn=generate_list,
522
+ inputs=[
523
+ input_image,
524
+ list_input,
525
+ steps,
526
+ randomize_seed,
527
+ seed,
528
+ randomize_cfg,
529
+ text_cfg_scale,
530
+ image_cfg_scale,
531
+ weather_close_video,
532
+ decode_image_batch
533
+ ],
534
+ outputs=[seed, text_cfg_scale, image_cfg_scale, edited_image, mix_image, edited_mask, mask_video, image_video, original_image, mix_result_with_red_mask],
535
+ )
536
+
537
  reset_button.click(
538
  fn=reset,
539
  inputs=[],