fffiloni commited on
Commit
2053864
1 Parent(s): 67e0c9d

add SDXL Turbo option

Browse files
Files changed (1) hide show
  1. app.py +23 -3
app.py CHANGED
@@ -1,8 +1,11 @@
1
  import gradio as gr
2
  from gradio_client import Client
3
  import os
 
 
4
 
5
  hf_token = os.environ.get("HF_TKN")
 
6
 
7
  def get_caption(image_in):
8
  client = Client("https://fffiloni-moondream1.hf.space/", hf_token=hf_token)
@@ -37,12 +40,29 @@ def get_sdxl_lightning(prompt):
37
  print(result)
38
  return result
39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  def infer(image_in, chosen_method):
41
  caption = get_caption(image_in)
42
  if chosen_method == "LCM" :
43
  img_var = get_lcm(caption)
44
  elif chosen_method == "SDXL Lightning" :
45
  img_var = get_sdxl_lightning(caption)
 
 
46
  return img_var
47
 
48
  gr.Interface(
@@ -51,14 +71,14 @@ gr.Interface(
51
  fn = infer,
52
  inputs = [
53
  gr.Image(type="filepath", label="Image input"),
54
- gr.Dropdown(label="Choose a model", choices=["LCM", "SDXL Lightning"], value="SDXL Lightning")
55
  ],
56
  outputs = [
57
  gr.Image(label="Image variation")
58
  ],
59
  examples = [
60
- ["examples/frog_clean.jpg", "SDXL Lightning"],
61
- ["examples/martin_pecheur.jpeg", "LCM"],
62
  ["examples/forest_deer.png", "SDXL Lightning"]
63
  ],
64
  cache_examples = False
 
1
  import gradio as gr
2
  from gradio_client import Client
3
  import os
4
+ import numpy as np
5
+ import random
6
 
7
  hf_token = os.environ.get("HF_TKN")
8
+ MAX_SEED = np.iinfo(np.int32).max
9
 
10
  def get_caption(image_in):
11
  client = Client("https://fffiloni-moondream1.hf.space/", hf_token=hf_token)
 
40
  print(result)
41
  return result
42
 
43
+ def get_turbo(prompt):
44
+ seed = random.randint(0, MAX_SEED)
45
+ print(f"SEED: {seed}")
46
+ client = Client("https://diffusers-unofficial-sdxl-turbo-i2i-t2i.hf.space/")
47
+ result = client.predict(
48
+ None, # filepath in 'Webcam' Image component
49
+ prompt, # str in 'parameter_5' Textbox component
50
+ 0.7, # float (numeric value between 0.0 and 1.0) in 'Strength' Slider component
51
+ 4, # float (numeric value between 1 and 10) in 'Steps' Slider component
52
+ seed, # float (numeric value between 0 and MAX_SEED) in 'Seed' Slider component
53
+ api_name="/predict"
54
+ )
55
+ print(result)
56
+ return result
57
+
58
  def infer(image_in, chosen_method):
59
  caption = get_caption(image_in)
60
  if chosen_method == "LCM" :
61
  img_var = get_lcm(caption)
62
  elif chosen_method == "SDXL Lightning" :
63
  img_var = get_sdxl_lightning(caption)
64
+ elif chosen_method == "SDXL Turbo" :
65
+ img_var = get_turbo(caption)
66
  return img_var
67
 
68
  gr.Interface(
 
71
  fn = infer,
72
  inputs = [
73
  gr.Image(type="filepath", label="Image input"),
74
+ gr.Dropdown(label="Choose a model", choices=["LCM", "SDXL Lightning", "SDXL Turbo"], value="SDXL Lightning")
75
  ],
76
  outputs = [
77
  gr.Image(label="Image variation")
78
  ],
79
  examples = [
80
+ ["examples/frog_clean.jpg", "LCM"],
81
+ ["examples/martin_pecheur.jpeg", "SDXL Turbo"],
82
  ["examples/forest_deer.png", "SDXL Lightning"]
83
  ],
84
  cache_examples = False