adamelliotfields commited on
Commit
fd9e8de
1 Parent(s): 7e65847

Toggle image inputs

Browse files
Files changed (2) hide show
  1. app.py +62 -26
  2. lib/inference.py +3 -3
app.py CHANGED
@@ -60,8 +60,9 @@ async def image_prompt_fn(images):
60
  return create_image_dropdown(images)
61
 
62
 
 
 
63
  async def image_select_fn(images, image, i):
64
- # -2 is the lock icon, -1 is None
65
  if i == -2:
66
  return gr.Image(image)
67
  if i == -1:
@@ -82,10 +83,18 @@ async def generate_fn(*args):
82
  prompt = None
83
  if prompt is None or prompt.strip() == "":
84
  raise gr.Error("You must enter a prompt")
 
 
 
 
 
 
 
 
85
  try:
86
  images = await async_call(
87
  generate,
88
- *args,
89
  Info=gr.Info,
90
  Error=gr.Error,
91
  progress=gr.Progress(),
@@ -119,6 +128,10 @@ with gr.Blocks(
119
  block_background_fill_dark=gr.themes.colors.gray.c900,
120
  ),
121
  ) as demo:
 
 
 
 
122
  gr.HTML(read_file("./partials/intro.html"))
123
 
124
  with gr.Accordion(
@@ -312,25 +325,26 @@ with gr.Blocks(
312
 
313
  # img2img tab
314
  with gr.TabItem("🖼️ Image"):
315
- with gr.Row():
316
- image_prompt = gr.Image(
317
- show_share_button=False,
318
- show_label=False,
319
- min_width=320,
320
- format="png",
321
- type="pil",
322
- )
323
- ip_image = gr.Image(
324
- show_share_button=False,
325
- label="IP-Adapter",
326
- min_width=320,
327
- format="png",
328
- type="pil",
329
- )
330
-
331
  with gr.Group():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
332
  with gr.Row():
333
  image_select = gr.Dropdown(
 
334
  choices=[("None", -1)],
335
  label="Gallery Image",
336
  interactive=True,
@@ -338,8 +352,9 @@ with gr.Blocks(
338
  value=-1,
339
  )
340
  ip_image_select = gr.Dropdown(
341
- choices=[("None", -1)],
342
  label="Gallery Image (IP-Adapter)",
 
343
  interactive=True,
344
  filterable=False,
345
  value=-1,
@@ -355,9 +370,19 @@ with gr.Blocks(
355
  )
356
 
357
  with gr.Row():
 
 
 
 
 
 
 
 
 
 
358
  ip_face = gr.Checkbox(
359
  elem_classes=["checkbox"],
360
- label="IP-Adapter Face",
361
  value=False,
362
  )
363
 
@@ -418,7 +443,7 @@ with gr.Blocks(
418
  file_format.change(
419
  lambda f: (gr.Gallery(format=f), gr.Image(format=f), gr.Image(format=f)),
420
  inputs=[file_format],
421
- outputs=[output_images, image_prompt, ip_image],
422
  show_api=False,
423
  )
424
 
@@ -433,7 +458,7 @@ with gr.Blocks(
433
  # lock the input images so you don't lose them when the gallery updates
434
  output_images.change(
435
  gallery_fn,
436
- inputs=[output_images, image_prompt, ip_image],
437
  outputs=[image_select, ip_image_select],
438
  show_api=False,
439
  )
@@ -447,8 +472,8 @@ with gr.Blocks(
447
  )
448
  ip_image_select.change(
449
  image_select_fn,
450
- inputs=[output_images, ip_image, ip_image_select],
451
- outputs=[ip_image],
452
  show_api=False,
453
  )
454
 
@@ -459,7 +484,7 @@ with gr.Blocks(
459
  outputs=[image_select],
460
  show_api=False,
461
  )
462
- ip_image.clear(
463
  image_prompt_fn,
464
  inputs=[output_images],
465
  outputs=[ip_image_select],
@@ -475,6 +500,15 @@ with gr.Blocks(
475
  js="() => { return null; }",
476
  )
477
 
 
 
 
 
 
 
 
 
 
478
  gr.on(
479
  triggers=[generate_btn.click, prompt.submit],
480
  fn=generate_fn,
@@ -485,7 +519,7 @@ with gr.Blocks(
485
  prompt,
486
  negative_prompt,
487
  image_prompt,
488
- ip_image,
489
  ip_face,
490
  lora_1,
491
  lora_1_weight,
@@ -508,6 +542,8 @@ with gr.Blocks(
508
  use_taesd,
509
  use_freeu,
510
  use_clip_skip,
 
 
511
  ],
512
  )
513
 
 
60
  return create_image_dropdown(images)
61
 
62
 
63
+ # handle selecting an image from the gallery
64
+ # -2 is the lock icon, -1 is None
65
  async def image_select_fn(images, image, i):
 
66
  if i == -2:
67
  return gr.Image(image)
68
  if i == -1:
 
83
  prompt = None
84
  if prompt is None or prompt.strip() == "":
85
  raise gr.Error("You must enter a prompt")
86
+
87
+ DISABLE_IMAGE_PROMPT, DISABLE_IP_IMAGE_PROMPT = args[-2:]
88
+ gen_args = list(args[:-2])
89
+ if DISABLE_IMAGE_PROMPT:
90
+ gen_args[2] = None
91
+ if DISABLE_IP_IMAGE_PROMPT:
92
+ gen_args[3] = None
93
+
94
  try:
95
  images = await async_call(
96
  generate,
97
+ *gen_args,
98
  Info=gr.Info,
99
  Error=gr.Error,
100
  progress=gr.Progress(),
 
128
  block_background_fill_dark=gr.themes.colors.gray.c900,
129
  ),
130
  ) as demo:
131
+ # override image inputs without clearing them
132
+ DISABLE_IMAGE_PROMPT = gr.State(False)
133
+ DISABLE_IP_IMAGE_PROMPT = gr.State(False)
134
+
135
  gr.HTML(read_file("./partials/intro.html"))
136
 
137
  with gr.Accordion(
 
325
 
326
  # img2img tab
327
  with gr.TabItem("🖼️ Image"):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
328
  with gr.Group():
329
+ with gr.Row():
330
+ image_prompt = gr.Image(
331
+ show_share_button=False,
332
+ label="Initial Image",
333
+ min_width=320,
334
+ format="png",
335
+ type="pil",
336
+ )
337
+ ip_image_prompt = gr.Image(
338
+ show_share_button=False,
339
+ label="IP-Adapter Image",
340
+ min_width=320,
341
+ format="png",
342
+ type="pil",
343
+ )
344
+
345
  with gr.Row():
346
  image_select = gr.Dropdown(
347
+ info="Use an initial image from the gallery",
348
  choices=[("None", -1)],
349
  label="Gallery Image",
350
  interactive=True,
 
352
  value=-1,
353
  )
354
  ip_image_select = gr.Dropdown(
355
+ info="Use an IP-Adapter image from the gallery",
356
  label="Gallery Image (IP-Adapter)",
357
+ choices=[("None", -1)],
358
  interactive=True,
359
  filterable=False,
360
  value=-1,
 
370
  )
371
 
372
  with gr.Row():
373
+ disable_image = gr.Checkbox(
374
+ elem_classes=["checkbox"],
375
+ label="Disable Initial Image",
376
+ value=False,
377
+ )
378
+ disable_ip_image = gr.Checkbox(
379
+ elem_classes=["checkbox"],
380
+ label="Disable IP-Adapter Image",
381
+ value=False,
382
+ )
383
  ip_face = gr.Checkbox(
384
  elem_classes=["checkbox"],
385
+ label="Use IP-Adapter Face",
386
  value=False,
387
  )
388
 
 
443
  file_format.change(
444
  lambda f: (gr.Gallery(format=f), gr.Image(format=f), gr.Image(format=f)),
445
  inputs=[file_format],
446
+ outputs=[output_images, image_prompt, ip_image_prompt],
447
  show_api=False,
448
  )
449
 
 
458
  # lock the input images so you don't lose them when the gallery updates
459
  output_images.change(
460
  gallery_fn,
461
+ inputs=[output_images, image_prompt, ip_image_prompt],
462
  outputs=[image_select, ip_image_select],
463
  show_api=False,
464
  )
 
472
  )
473
  ip_image_select.change(
474
  image_select_fn,
475
+ inputs=[output_images, ip_image_prompt, ip_image_select],
476
+ outputs=[ip_image_prompt],
477
  show_api=False,
478
  )
479
 
 
484
  outputs=[image_select],
485
  show_api=False,
486
  )
487
+ ip_image_prompt.clear(
488
  image_prompt_fn,
489
  inputs=[output_images],
490
  outputs=[ip_image_select],
 
500
  js="() => { return null; }",
501
  )
502
 
503
+ # toggle image prompts by updating session state
504
+ gr.on(
505
+ triggers=[disable_image.input, disable_ip_image.input],
506
+ fn=lambda disable_image, disable_ip_image: (disable_image, disable_ip_image),
507
+ inputs=[disable_image, disable_ip_image],
508
+ outputs=[DISABLE_IMAGE_PROMPT, DISABLE_IP_IMAGE_PROMPT],
509
+ )
510
+
511
+ # generate images
512
  gr.on(
513
  triggers=[generate_btn.click, prompt.submit],
514
  fn=generate_fn,
 
519
  prompt,
520
  negative_prompt,
521
  image_prompt,
522
+ ip_image_prompt,
523
  ip_face,
524
  lora_1,
525
  lora_1_weight,
 
542
  use_taesd,
543
  use_freeu,
544
  use_clip_skip,
545
+ DISABLE_IMAGE_PROMPT,
546
+ DISABLE_IP_IMAGE_PROMPT,
547
  ],
548
  )
549
 
lib/inference.py CHANGED
@@ -99,7 +99,7 @@ def generate(
99
  positive_prompt,
100
  negative_prompt="",
101
  image_prompt=None,
102
- ip_image=None,
103
  ip_face=False,
104
  lora_1=None,
105
  lora_1_weight=0.0,
@@ -144,7 +144,7 @@ def generate(
144
  else ReturnedEmbeddingsType.LAST_HIDDEN_STATES_NORMALIZED
145
  )
146
 
147
- if ip_image:
148
  IP_ADAPTER = "full-face" if ip_face else "plus"
149
  else:
150
  IP_ADAPTER = ""
@@ -298,7 +298,7 @@ def generate(
298
  if IP_ADAPTER:
299
  # don't resize full-face images since they are usually square crops
300
  size = None if ip_face else (width, height)
301
- kwargs["ip_adapter_image"] = prepare_image(ip_image, size)
302
 
303
  try:
304
  image = pipe(**kwargs).images[0]
 
99
  positive_prompt,
100
  negative_prompt="",
101
  image_prompt=None,
102
+ ip_image_prompt=None,
103
  ip_face=False,
104
  lora_1=None,
105
  lora_1_weight=0.0,
 
144
  else ReturnedEmbeddingsType.LAST_HIDDEN_STATES_NORMALIZED
145
  )
146
 
147
+ if ip_image_prompt:
148
  IP_ADAPTER = "full-face" if ip_face else "plus"
149
  else:
150
  IP_ADAPTER = ""
 
298
  if IP_ADAPTER:
299
  # don't resize full-face images since they are usually square crops
300
  size = None if ip_face else (width, height)
301
+ kwargs["ip_adapter_image"] = prepare_image(ip_image_prompt, size)
302
 
303
  try:
304
  image = pipe(**kwargs).images[0]