adamelliotfields committed (verified)
Commit 297482a · 1 Parent(s): 612684c

Add initial Fal txt2img models

Files changed (4)
  1. lib/api.py +27 -16
  2. lib/config.py +69 -13
  3. lib/presets.py +102 -7
  4. pages/2_🎨_Text_to_Image.py +84 -45
lib/api.py CHANGED
@@ -1,3 +1,4 @@
+import base64
 import io
 
 import requests
@@ -24,25 +25,35 @@ def txt2txt_generate(api_key, service, model, parameters, **kwargs):
 
 
 def txt2img_generate(api_key, service, model, inputs, parameters, **kwargs):
-    headers = {
-        "Authorization": f"Bearer {api_key}",
-        "X-Wait-For-Model": "true",
-        "X-Use-Cache": "false",
-    }
+    headers = {}
+    if service == "Huggingface":
+        headers["Authorization"] = f"Bearer {api_key}"
+        headers["X-Wait-For-Model"] = "true"
+        headers["X-Use-Cache"] = "false"
+    if service == "Fal":
+        headers["Authorization"] = f"Key {api_key}"
+
+    json = {}
+    if service == "Huggingface":
+        json = {
+            "inputs": inputs,
+            "parameters": {**parameters, **kwargs},
+        }
+    if service == "Fal":
+        json = {**parameters, **kwargs}
+        json["prompt"] = inputs
+
     base_url = f"{Config.SERVICES[service]}/{model}"
 
     try:
-        response = requests.post(
-            base_url,
-            headers=headers,
-            json={
-                "inputs": inputs,
-                "parameters": {**parameters, **kwargs},
-            },
-        )
-        if response.status_code == 200:
-            return Image.open(io.BytesIO(response.content))
+        response = requests.post(base_url, headers=headers, json=json)
+        if response.status_code // 100 == 2:  # 2xx
+            if service == "Huggingface":
+                return Image.open(io.BytesIO(response.content))
+            if service == "Fal":
+                bytes = base64.b64decode(response.json()["images"][0]["url"].split(",")[-1])
+                return Image.open(io.BytesIO(bytes))
         else:
-            return f"Error: {response.status_code} - {response.text}"
+            return f"Error: {response.status_code} {response.text}"
     except Exception as e:
         return str(e)
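
The reworked txt2img_generate branches on the service: Huggingface keeps the Bearer-token headers and the inputs/parameters payload, while Fal sends a Key-prefixed Authorization header, flattens the parameters into the JSON body alongside a prompt field, and decodes the base64 data URI returned in images[0].url. A minimal usage sketch of the new Fal branch (not part of the commit; the model id and parameter values are illustrative, and FAL_KEY is assumed to be set in the environment):

import os

from lib import txt2img_generate

image = txt2img_generate(
    api_key=os.environ["FAL_KEY"],
    service="Fal",
    model="fal-ai/kolors",
    inputs="a watercolor fox in a snowy forest",
    parameters={"image_size": "square_hd", "guidance_scale": 5.0, "num_inference_steps": 50},
    sync_mode=True,  # with sync_mode, Fal returns the image as a base64 data URI, which the function decodes
)
if not isinstance(image, str):  # error responses and exceptions come back as strings
    image.save("fox.png")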
lib/config.py CHANGED
@@ -7,22 +7,78 @@ Config = SimpleNamespace(
     SERVICES={
         "Huggingface": "https://api-inference.huggingface.co/models",
         "Perplexity": "https://api.perplexity.ai",
+        "Fal": "https://fal.run",
     },
-    TXT2IMG_NEGATIVE_PROMPT="ugly, unattractive, malformed, mutated, disgusting, blurry, grainy, noisy, oversaturated, undersaturated, overexposed, underexposed, worst quality, low details, lowres, watermark, signature, autograph, trademark, sloppy, cluttered",
-    TXT2IMG_DEFAULT_MODEL=2,
-    TXT2IMG_MODELS=[
-        "black-forest-labs/flux.1-dev",
-        "black-forest-labs/flux.1-schnell",
-        "stabilityai/stable-diffusion-xl-base-1.0",
+    TXT2IMG_HIDDEN_PARAMETERS=[
+        "enable_safety_checker",
+        "max_sequence_length",
+        "num_images",
+        "output_format",
+        "performance",
+        "safety_tolerance",
+        "scheduler",
+        "sharpness",
+        "style",
+        "styles",
+        "sync_mode",
     ],
-    TXT2IMG_DEFAULT_AR="1:1",
-    TXT2IMG_AR={
-        "7:4": (1344, 768),
-        "9:7": (1152, 896),
-        "1:1": (1024, 1024),
-        "7:9": (896, 1152),
-        "4:7": (768, 1344),
+    TXT2IMG_NEGATIVE_PROMPT="ugly, unattractive, malformed, mutated, disgusting, blurry, grainy, noisy, oversaturated, undersaturated, overexposed, underexposed, worst quality, low details, lowres, watermark, signature, autograph, trademark, sloppy, cluttered",
+    TXT2IMG_DEFAULT_MODEL={
+        "Huggingface": 2,
+        "Fal": 1,
+    },
+    TXT2IMG_MODELS={
+        "Huggingface": [
+            "black-forest-labs/flux.1-dev",
+            "black-forest-labs/flux.1-schnell",
+            "stabilityai/stable-diffusion-xl-base-1.0",
+        ],
+        "Fal": [
+            # "fal-ai/aura-flow",
+            # "fal-ai/flux-pro",
+            "fal-ai/fooocus",
+            "fal-ai/kolors",
+            "fal-ai/pixart-sigma",
+            "fal-ai/stable-diffusion-v3-medium",
+        ],
     },
+    TXT2IMG_DEFAULT_IMAGE_SIZE="square_hd",  # fal image sizes
+    TXT2IMG_IMAGE_SIZES=[
+        "landscape_16_9",
+        "landscape_4_3",
+        "square_hd",
+        "portrait_4_3",
+        "portrait_16_9",
+    ],
+    TXT2IMG_DEFAULT_ASPECT_RATIO="1024x1024",  # fooocus aspect ratios
+    TXT2IMG_ASPECT_RATIOS=[
+        "704x1408",
+        "704x1344",
+        "768x1344",
+        "768x1280",
+        "832x1216",
+        "832x1152",
+        "896x1152",
+        "896x1088",
+        "960x1088",
+        "960x1024",
+        "1024x1024",
+        "1024x960",
+        "1088x960",
+        "1088x896",
+        "1152x896",
+        "1152x832",
+        "1216x832",
+        "1280x768",
+        "1344x768",
+        "1344x704",
+        "1408x704",
+        "1472x704",
+        "1536x640",
+        "1600x640",
+        "1664x576",
+        "1728x576",
+    ],
     TXT2TXT_DEFAULT_SYSTEM="You are a helpful assistant. Be precise and concise.",
     TXT2TXT_DEFAULT_MODEL={
         "Huggingface": 4,
lib/presets.py CHANGED
@@ -19,18 +19,40 @@ ServicePresets = SimpleNamespace(
 
 # txt2img models
 ModelPresets = SimpleNamespace(
-    FLUX_1_DEV={
+    AURA_FLOW={
+        "name": "AuraFlow",
+        "guidance_scale": 3.5,
+        "guidance_scale_min": 1.0,
+        "guidance_scale_max": 10.0,
+        "num_inference_steps": 50,
+        "num_inference_steps_min": 10,
+        "num_inference_steps_max": 50,
+        "parameters": ["seed", "guidance_scale", "num_inference_steps", "expand_prompt"],
+        "kwargs": {"num_images": 1},
+    },
+    FLUX_DEV={
         "name": "FLUX.1 Dev",
-        "num_inference_steps": 30,
+        "num_inference_steps": 28,
         "num_inference_steps_min": 10,
-        "num_inference_steps_max": 40,
+        "num_inference_steps_max": 50,
         "guidance_scale": 3.5,
         "guidance_scale_min": 1.0,
-        "guidance_scale_max": 7.0,
+        "guidance_scale_max": 10.0,
         "parameters": ["width", "height", "guidance_scale", "num_inference_steps"],
         "kwargs": {"max_sequence_length": 512},
     },
-    FLUX_1_SCHNELL={
+    FLUX_PRO={
+        "name": "FLUX.1 Pro",
+        "num_inference_steps": 28,
+        "num_inference_steps_min": 10,
+        "num_inference_steps_max": 50,
+        "guidance_scale": 3.5,
+        "guidance_scale_min": 1.0,
+        "guidance_scale_max": 10.0,
+        "parameters": ["seed", "image_size", "num_inference_steps", "guidance_scale"],
+        "kwargs": {"num_images": 1, "sync_mode": True, "safety_tolerance": 6},
+    },
+    FLUX_SCHNELL={
         "name": "FLUX.1 Schnell",
         "num_inference_steps": 4,
         "num_inference_steps_min": 1,
@@ -38,14 +60,87 @@ ModelPresets = SimpleNamespace(
         "parameters": ["width", "height", "num_inference_steps"],
         "kwargs": {"guidance_scale": 0.0, "max_sequence_length": 256},
     },
+    FOOOCUS={
+        "name": "Fooocus",
+        "guidance_scale": 4.0,
+        "guidance_scale_min": 1.0,
+        "guidance_scale_max": 10.0,
+        "parameters": ["seed", "negative_prompt", "aspect_ratio", "guidance_scale"],
+        "kwargs": {
+            "num_images": 1,
+            "sync_mode": True,
+            "enable_safety_checker": False,
+            "output_format": "png",
+            "sharpness": 2,
+            "styles": ["Fooocus Enhance", "Fooocus V2", "Fooocus Sharp"],
+            "performance": "Quality",
+        },
+    },
+    KOLORS={
+        "name": "Kolors",
+        "guidance_scale": 5.0,
+        "guidance_scale_min": 1.0,
+        "guidance_scale_max": 10.0,
+        "num_inference_steps": 50,
+        "num_inference_steps_min": 10,
+        "num_inference_steps_max": 50,
+        "parameters": [
+            "seed",
+            "negative_prompt",
+            "image_size",
+            "guidance_scale",
+            "num_inference_steps",
+        ],
+        "kwargs": {
+            "num_images": 1,
+            "sync_mode": True,
+            "enable_safety_checker": False,
+            "scheduler": "EulerDiscreteScheduler",
+        },
+    },
+    PIXART_SIGMA={
+        "name": "PixArt-Σ",
+        "guidance_scale": 4.5,
+        "guidance_scale_min": 1.0,
+        "guidance_scale_max": 10.0,
+        "num_inference_steps": 35,
+        "num_inference_steps_min": 10,
+        "num_inference_steps_max": 50,
+        "parameters": ["seed", "negative_prompt", "image_size", "guidance_scale", "num_inference_steps"],
+        "kwargs": {
+            "num_images": 1,
+            "sync_mode": True,
+            "enable_safety_checker": False,
+            "style": "(No style)",
+            "scheduler": "SA-SOLVER",
+        },
+    },
+    STABLE_DIFFUSION_3={
+        "name": "SD3",
+        "guidance_scale": 5.0,
+        "guidance_scale_min": 1.0,
+        "guidance_scale_max": 10.0,
+        "num_inference_steps": 28,
+        "num_inference_steps_min": 10,
+        "num_inference_steps_max": 50,
+        "parameters": [
+            "seed",
+            "negative_prompt",
+            "image_size",
+            "guidance_scale",
+            "num_inference_steps",
+            "prompt_expansion",
+        ],
+        "kwargs": {"num_images": 1, "sync_mode": True, "enable_safety_checker": False},
+    },
     STABLE_DIFFUSION_XL={
         "name": "SDXL",
         "guidance_scale": 7.0,
         "guidance_scale_min": 1.0,
-        "guidance_scale_max": 15.0,
+        "guidance_scale_max": 10.0,
         "num_inference_steps": 40,
         "num_inference_steps_min": 10,
         "num_inference_steps_max": 50,
-        "parameters": ["width", "height", "guidance_scale", "num_inference_steps", "seed", "negative_prompt"],
+        "parameters": ["seed", "negative_prompt", "width", "height", "guidance_scale", "num_inference_steps"],
     },
 )
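
Each preset pairs the user-facing controls ("parameters" plus the *_min/*_max bounds read by the sliders) with fixed "kwargs" that are merged into the request but kept out of the UI. A sketch of how a request body is assembled from a preset, mirroring the page's parameters.update(preset["kwargs"]) step (the specific values are illustrative):

from lib import ModelPresets

preset = ModelPresets.KOLORS
parameters = {
    "seed": 42,
    "negative_prompt": "",
    "image_size": "square_hd",
    "guidance_scale": preset["guidance_scale"],            # 5.0 default
    "num_inference_steps": preset["num_inference_steps"],  # 50 default
}
parameters.update(preset["kwargs"])  # adds num_images, sync_mode, enable_safety_checker, scheduler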
pages/2_🎨_Text_to_Image.py CHANGED
@@ -5,16 +5,16 @@ import streamlit as st
 
 from lib import Config, ModelPresets, txt2img_generate
 
-HF_TOKEN = None
-FAL_KEY = None
-# HF_TOKEN = os.environ.get("HF_TOKEN") or None
-# FAL_KEY = os.environ.get("FAL_KEY") or None
-API_URL = "https://api-inference.huggingface.co/models"
-HEADERS = {"Authorization": f"Bearer {HF_TOKEN}", "X-Wait-For-Model": "true", "X-Use-Cache": "false"}
+HF_TOKEN = os.environ.get("HF_TOKEN") or None
+FAL_KEY = os.environ.get("FAL_KEY") or None
 PRESET_MODEL = {
-    "black-forest-labs/flux.1-dev": ModelPresets.FLUX_1_DEV,
-    "black-forest-labs/flux.1-schnell": ModelPresets.FLUX_1_SCHNELL,
+    "black-forest-labs/flux.1-dev": ModelPresets.FLUX_DEV,
+    "black-forest-labs/flux.1-schnell": ModelPresets.FLUX_SCHNELL,
     "stabilityai/stable-diffusion-xl-base-1.0": ModelPresets.STABLE_DIFFUSION_XL,
+    "fal-ai/fooocus": ModelPresets.FOOOCUS,
+    "fal-ai/kolors": ModelPresets.KOLORS,
+    "fal-ai/pixart-sigma": ModelPresets.PIXART_SIGMA,
+    "fal-ai/stable-diffusion-v3-medium": ModelPresets.STABLE_DIFFUSION_3,
 }
 
 # config
@@ -34,8 +34,8 @@ if "api_key_huggingface" not in st.session_state:
 if "txt2img_messages" not in st.session_state:
     st.session_state.txt2img_messages = []
 
-if "txt2img_running" not in st.session_state:
-    st.session_state.txt2img_running = False
+if "running" not in st.session_state:
+    st.session_state.running = False
 
 if "txt2img_seed" not in st.session_state:
     st.session_state.txt2img_seed = 0
@@ -45,9 +45,9 @@ st.logo("logo.svg")
 st.sidebar.header("Settings")
 service = st.sidebar.selectbox(
     "Service",
-    options=["Huggingface"],
-    index=0,
-    disabled=st.session_state.txt2img_running,
+    options=["Fal", "Huggingface"],
+    index=1,
+    disabled=st.session_state.running,
 )
 
 if service == "Huggingface" and HF_TOKEN is None:
@@ -56,7 +56,7 @@ if service == "Huggingface" and HF_TOKEN is None:
         type="password",
         help="Cleared on page refresh",
         value=st.session_state.api_key_huggingface,
-        disabled=st.session_state.txt2img_running,
+        disabled=st.session_state.running,
     )
 else:
     st.session_state.api_key_huggingface = st.session_state.api_key_huggingface
@@ -67,7 +67,7 @@ if service == "Fal" and FAL_KEY is None:
         type="password",
         help="Cleared on page refresh",
         value=st.session_state.api_key_fal,
-        disabled=st.session_state.txt2img_running,
+        disabled=st.session_state.running,
    )
 else:
     st.session_state.api_key_fal = st.session_state.api_key_fal
@@ -80,17 +80,11 @@ if service == "Fal" and FAL_KEY is not None:
 
 model = st.sidebar.selectbox(
     "Model",
-    options=Config.TXT2IMG_MODELS,
-    index=Config.TXT2IMG_DEFAULT_MODEL,
-    disabled=st.session_state.txt2img_running,
+    options=Config.TXT2IMG_MODELS[service],
+    index=Config.TXT2IMG_DEFAULT_MODEL[service],
+    disabled=st.session_state.running,
     format_func=lambda x: x.split("/")[1],
 )
-aspect_ratio = st.sidebar.select_slider(
-    "Aspect Ratio",
-    options=list(Config.TXT2IMG_AR.keys()),
-    value=Config.TXT2IMG_DEFAULT_AR,
-    disabled=st.session_state.txt2img_running,
-)
 
 # heading
 st.html("""
@@ -102,10 +96,52 @@ st.html("""
 parameters = {}
 preset = PRESET_MODEL[model]
 for param in preset["parameters"]:
+    if param == "seed":
+        parameters[param] = st.sidebar.number_input(
+            "Seed",
+            min_value=-1,
+            max_value=(1 << 53) - 1,
+            value=-1,
+            disabled=st.session_state.running,
+        )
+    if param == "negative_prompt":
+        parameters[param] = st.sidebar.text_area(
+            "Negative Prompt",
+            value=Config.TXT2IMG_NEGATIVE_PROMPT,
+            disabled=st.session_state.running,
+        )
     if param == "width":
-        parameters[param] = Config.TXT2IMG_AR[aspect_ratio][0]
+        parameters[param] = st.sidebar.slider(
+            "Width",
+            step=64,
+            value=1024,
+            min_value=512,
+            max_value=2048,
+            disabled=st.session_state.running,
+        )
     if param == "height":
-        parameters[param] = Config.TXT2IMG_AR[aspect_ratio][1]
+        parameters[param] = st.sidebar.slider(
+            "Height",
+            step=64,
+            value=1024,
+            min_value=512,
+            max_value=2048,
+            disabled=st.session_state.running,
+        )
+    if param == "image_size":
+        parameters[param] = st.sidebar.select_slider(
+            "Image Size",
+            options=Config.TXT2IMG_IMAGE_SIZES,
+            value=Config.TXT2IMG_DEFAULT_IMAGE_SIZE,
+            disabled=st.session_state.running,
+        )
+    if param == "aspect_ratio":
+        parameters[param] = st.sidebar.select_slider(
+            "Aspect Ratio",
+            options=Config.TXT2IMG_ASPECT_RATIOS,
+            value=Config.TXT2IMG_DEFAULT_ASPECT_RATIO,
+            disabled=st.session_state.running,
+        )
     if param == "guidance_scale":
         parameters[param] = st.sidebar.slider(
             "Guidance Scale",
@@ -113,7 +149,7 @@ for param in preset["parameters"]:
             preset["guidance_scale_max"],
             preset["guidance_scale"],
             0.1,
-            disabled=st.session_state.txt2img_running,
+            disabled=st.session_state.running,
         )
     if param == "num_inference_steps":
         parameters[param] = st.sidebar.slider(
@@ -122,21 +158,19 @@ for param in preset["parameters"]:
             preset["num_inference_steps_max"],
             preset["num_inference_steps"],
             1,
-            disabled=st.session_state.txt2img_running,
+            disabled=st.session_state.running,
         )
-    if param == "seed":
-        parameters[param] = st.sidebar.number_input(
-            "Seed",
-            min_value=-1,
-            max_value=(1 << 53) - 1,
-            value=-1,
-            disabled=st.session_state.txt2img_running,
+    if param == "expand_prompt":
+        parameters[param] = st.sidebar.checkbox(
+            "Expand Prompt",
+            value=False,
+            disabled=st.session_state.running,
         )
-    if param == "negative_prompt":
-        parameters[param] = st.sidebar.text_area(
-            label="Negative Prompt",
-            value=Config.TXT2IMG_NEGATIVE_PROMPT,
-            disabled=st.session_state.txt2img_running,
+    if param == "prompt_expansion":
+        parameters[param] = st.sidebar.checkbox(
+            "Prompt Expansion",
+            value=False,
+            disabled=st.session_state.running,
         )
 
 # wrap the prompt in an expander to display additional parameters
@@ -154,8 +188,13 @@ for message in st.session_state.txt2img_messages:
            div[data-testid="stMarkdownContainer"] p:not(:last-of-type) { margin-bottom: 0 }
        </style>
        """)
+        filtered_parameters = {
+            k: v
+            for k, v in message["parameters"].items()
+            if k not in Config.TXT2IMG_HIDDEN_PARAMETERS
+        }
        md = f"`model`: {message['model']}\n\n"
-        md += "\n\n".join([f"`{k}`: {v}" for k, v in message["parameters"].items()])
+        md += "\n\n".join([f"`{k}`: {v}" for k, v in filtered_parameters.items()])
        st.markdown(md)
 
        if role == "assistant":
@@ -192,7 +231,7 @@ if st.session_state.txt2img_messages:
    col1, col2 = st.columns(2)
    with col1:
        if (
-            st.button("❌", help="Delete last generation", disabled=st.session_state.txt2img_running)
+            st.button("❌", help="Delete last generation", disabled=st.session_state.running)
            and len(st.session_state.txt2img_messages) >= 2
        ):
            st.session_state.txt2img_messages.pop()
@@ -200,7 +239,7 @@ if st.session_state.txt2img_messages:
            st.rerun()
 
    with col2:
-        if st.button("🗑️", help="Clear all generations", disabled=st.session_state.txt2img_running):
+        if st.button("🗑️", help="Clear all generations", disabled=st.session_state.running):
            st.session_state.txt2img_messages = []
            st.session_state.txt2img_seed = 0
            st.rerun()
@@ -210,7 +249,7 @@ else:
 
 # show the prompt and spinner while loading then update state and re-render
 if prompt := st.chat_input(
     "What do you want to see?",
-    on_submit=lambda: setattr(st.session_state, "txt2img_running", True),
+    on_submit=lambda: setattr(st.session_state, "running", True),
 ):
     if "seed" in parameters and parameters["seed"] >= 0:
@@ -231,7 +270,7 @@ if prompt := st.chat_input(
     parameters.update(preset["kwargs"])
     api_key = getattr(st.session_state, f"api_key_{service.lower()}", None)
     image = txt2img_generate(api_key, service, model, prompt, parameters)
-    st.session_state.txt2img_running = False
+    st.session_state.running = False
 
     model_name = PRESET_MODEL[model]["name"]
     st.session_state.txt2img_messages.append(
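
On the history view, each message's parameters are now filtered against Config.TXT2IMG_HIDDEN_PARAMETERS before rendering, so fixed kwargs such as sync_mode or num_images don't clutter the chat. A standalone sketch of that filter (the message dict here is illustrative, not an actual session value):

from lib import Config

message = {
    "model": "fal-ai/pixart-sigma",
    "parameters": {"seed": 42, "guidance_scale": 4.5, "sync_mode": True, "num_images": 1},
}
filtered = {k: v for k, v in message["parameters"].items() if k not in Config.TXT2IMG_HIDDEN_PARAMETERS}
# -> {"seed": 42, "guidance_scale": 4.5}
md = f"`model`: {message['model']}\n\n" + "\n\n".join(f"`{k}`: {v}" for k, v in filtered.items())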