LiruiZhao commited on
Commit
900a405
1 Parent(s): 8f2f7d0

[Minor] debug list generation

Browse files
Files changed (1) hide show
  1. app.py +11 -3
app.py CHANGED
@@ -282,7 +282,7 @@ def generate(
282
 
283
  return [int(seed), text_cfg_scale, image_cfg_scale, edited_image, mix_image, edited_mask_copy, mask_video_path, image_video_path, input_image_copy, mix_result_with_red_mask]
284
 
285
- @spaces.GPU(duration=30)
286
  def generate_list(
287
  input_image: Image.Image,
288
  generate_list: str,
@@ -316,6 +316,8 @@ def generate_list(
316
  model.cuda()
317
  image_video = [np.array(input_image).astype(np.uint8)]
318
  generate_index = 0
 
 
319
  input_image_copy = input_image.convert("RGB")
320
  while generate_index < len(generate_list):
321
  print(f'generate_index: {str(generate_index)}')
@@ -352,6 +354,9 @@ def generate_list(
352
 
353
  if torch.sum(x_1).item()/x_1.numel() < -0.99:
354
  seed += 1
 
 
 
355
  continue
356
  else:
357
  generate_index += 1
@@ -402,7 +407,7 @@ def reset():
402
  def get_example():
403
  return [
404
  ["example_images/dufu.png", "", "black and white suit\nsunglasses\nblue medical mask\nyellow schoolbag\nred bow tie\nbrown high-top hat", 100, "Fix Seed", 1372, "Fix CFG", 7.5, 1.5],
405
- ["example_images/girl.jpeg", "reflective sunglasses", "", 100, "Fix Seed", 1372, "Fix CFG", 7.5, 1.5],
406
  ["example_images/dufu.png", "black and white suit", "", 100, "Fix Seed", 1372, "Fix CFG", 7.5, 1.5],
407
  ["example_images/girl.jpeg", "reflective sunglasses", "", 100, "Fix Seed", 1372, "Fix CFG", 7.5, 1.5],
408
  ["example_images/road_sign.png", "stop sign", "", 100, "Fix Seed", 1372, "Fix CFG", 7.5, 1.5],
@@ -497,7 +502,10 @@ with gr.Blocks(css="footer {visibility: hidden}") as demo:
497
  with gr.Row():
498
  gr.Examples(
499
  examples=get_example(),
500
- inputs=[input_image, instruction, list_input, steps, randomize_seed, seed, randomize_cfg, text_cfg_scale, image_cfg_scale],
 
 
 
501
  )
502
 
503
  generate_button.click(
 
282
 
283
  return [int(seed), text_cfg_scale, image_cfg_scale, edited_image, mix_image, edited_mask_copy, mask_video_path, image_video_path, input_image_copy, mix_result_with_red_mask]
284
 
285
+ @spaces.GPU(duration=180)
286
  def generate_list(
287
  input_image: Image.Image,
288
  generate_list: str,
 
316
  model.cuda()
317
  image_video = [np.array(input_image).astype(np.uint8)]
318
  generate_index = 0
319
+ retry_number = 0
320
+ max_retry = 10
321
  input_image_copy = input_image.convert("RGB")
322
  while generate_index < len(generate_list):
323
  print(f'generate_index: {str(generate_index)}')
 
354
 
355
  if torch.sum(x_1).item()/x_1.numel() < -0.99:
356
  seed += 1
357
+ retry_number +=1
358
+ if retry_number > max_retry:
359
+ generate_index += 1
360
  continue
361
  else:
362
  generate_index += 1
 
407
  def get_example():
408
  return [
409
  ["example_images/dufu.png", "", "black and white suit\nsunglasses\nblue medical mask\nyellow schoolbag\nred bow tie\nbrown high-top hat", 100, "Fix Seed", 1372, "Fix CFG", 7.5, 1.5],
410
+ ["example_images/girl.jpeg", "", "reflective sunglasses\nshiny golden crown\ndiamond necklace\ngorgeous yellow gown", 100, "Fix Seed", 1372, "Fix CFG", 7.5, 1.5],
411
  ["example_images/dufu.png", "black and white suit", "", 100, "Fix Seed", 1372, "Fix CFG", 7.5, 1.5],
412
  ["example_images/girl.jpeg", "reflective sunglasses", "", 100, "Fix Seed", 1372, "Fix CFG", 7.5, 1.5],
413
  ["example_images/road_sign.png", "stop sign", "", 100, "Fix Seed", 1372, "Fix CFG", 7.5, 1.5],
 
502
  with gr.Row():
503
  gr.Examples(
504
  examples=get_example(),
505
+ inputs=[input_image, instruction, list_input, steps, randomize_seed, seed, randomize_cfg, text_cfg_scale, image_cfg_scale, weather_close_video, decode_image_batch],
506
+ fn=None,
507
+ outputs=[seed, text_cfg_scale, image_cfg_scale, edited_image, mix_image, edited_mask, mask_video, image_video, original_image, mix_result_with_red_mask],
508
+ cache_examples = False
509
  )
510
 
511
  generate_button.click(