adamelliotfields commited on
Commit
a5515e4
1 Parent(s): 297482a

Better environment variable handling

Browse files
lib/api.py CHANGED
@@ -11,7 +11,7 @@ from .config import Config
11
 
12
  def txt2txt_generate(api_key, service, model, parameters, **kwargs):
13
  base_url = Config.SERVICES[service]
14
- if service == "Huggingface":
15
  base_url = f"{base_url}/{model}/v1"
16
  client = OpenAI(api_key=api_key, base_url=base_url)
17
 
@@ -19,14 +19,15 @@ def txt2txt_generate(api_key, service, model, parameters, **kwargs):
19
  stream = client.chat.completions.create(stream=True, model=model, **parameters, **kwargs)
20
  return st.write_stream(stream)
21
  except APIError as e:
22
- return e.message
 
23
  except Exception as e:
24
  return str(e)
25
 
26
 
27
  def txt2img_generate(api_key, service, model, inputs, parameters, **kwargs):
28
  headers = {}
29
- if service == "Huggingface":
30
  headers["Authorization"] = f"Bearer {api_key}"
31
  headers["X-Wait-For-Model"] = "true"
32
  headers["X-Use-Cache"] = "false"
@@ -34,7 +35,7 @@ def txt2img_generate(api_key, service, model, inputs, parameters, **kwargs):
34
  headers["Authorization"] = f"Key {api_key}"
35
 
36
  json = {}
37
- if service == "Huggingface":
38
  json = {
39
  "inputs": inputs,
40
  "parameters": {**parameters, **kwargs},
@@ -48,7 +49,7 @@ def txt2img_generate(api_key, service, model, inputs, parameters, **kwargs):
48
  try:
49
  response = requests.post(base_url, headers=headers, json=json)
50
  if response.status_code // 100 == 2: # 2xx
51
- if service == "Huggingface":
52
  return Image.open(io.BytesIO(response.content))
53
  if service == "Fal":
54
  bytes = base64.b64decode(response.json()["images"][0]["url"].split(",")[-1])
 
11
 
12
  def txt2txt_generate(api_key, service, model, parameters, **kwargs):
13
  base_url = Config.SERVICES[service]
14
+ if service == "Hugging Face":
15
  base_url = f"{base_url}/{model}/v1"
16
  client = OpenAI(api_key=api_key, base_url=base_url)
17
 
 
19
  stream = client.chat.completions.create(stream=True, model=model, **parameters, **kwargs)
20
  return st.write_stream(stream)
21
  except APIError as e:
22
+ # OpenAI uses this message for streaming errors and attaches response.error to error.body
23
+ return e.body if e.message == "An error occurred during streaming" else e.message
24
  except Exception as e:
25
  return str(e)
26
 
27
 
28
  def txt2img_generate(api_key, service, model, inputs, parameters, **kwargs):
29
  headers = {}
30
+ if service == "Hugging Face":
31
  headers["Authorization"] = f"Bearer {api_key}"
32
  headers["X-Wait-For-Model"] = "true"
33
  headers["X-Use-Cache"] = "false"
 
35
  headers["Authorization"] = f"Key {api_key}"
36
 
37
  json = {}
38
+ if service == "Hugging Face":
39
  json = {
40
  "inputs": inputs,
41
  "parameters": {**parameters, **kwargs},
 
49
  try:
50
  response = requests.post(base_url, headers=headers, json=json)
51
  if response.status_code // 100 == 2: # 2xx
52
+ if service == "Hugging Face":
53
  return Image.open(io.BytesIO(response.content))
54
  if service == "Fal":
55
  bytes = base64.b64decode(response.json()["images"][0]["url"].split(",")[-1])
lib/config.py CHANGED
@@ -5,11 +5,12 @@ Config = SimpleNamespace(
5
  ICON="⚡",
6
  LAYOUT="wide",
7
  SERVICES={
8
- "Huggingface": "https://api-inference.huggingface.co/models",
9
  "Perplexity": "https://api.perplexity.ai",
10
  "Fal": "https://fal.run",
11
  },
12
  TXT2IMG_HIDDEN_PARAMETERS=[
 
13
  "enable_safety_checker",
14
  "max_sequence_length",
15
  "num_images",
@@ -22,18 +23,20 @@ Config = SimpleNamespace(
22
  "styles",
23
  "sync_mode",
24
  ],
25
- TXT2IMG_NEGATIVE_PROMPT="ugly, unattractive, malformed, mutated, disgusting, blurry, grainy, noisy, oversaturated, undersaturated, overexposed, underexposed, worst quality, low details, lowres, watermark, signature, autograph, trademark, sloppy, cluttered",
26
  TXT2IMG_DEFAULT_MODEL={
27
- "Huggingface": 2,
 
28
  "Fal": 1,
29
  },
30
  TXT2IMG_MODELS={
31
- "Huggingface": [
32
  "black-forest-labs/flux.1-dev",
33
  "black-forest-labs/flux.1-schnell",
34
  "stabilityai/stable-diffusion-xl-base-1.0",
35
  ],
36
  "Fal": [
 
37
  # "fal-ai/aura-flow",
38
  # "fal-ai/flux-pro",
39
  "fal-ai/fooocus",
@@ -52,40 +55,36 @@ Config = SimpleNamespace(
52
  ],
53
  TXT2IMG_DEFAULT_ASPECT_RATIO="1024x1024", # fooocus aspect ratios
54
  TXT2IMG_ASPECT_RATIOS=[
55
- "704x1408",
56
- "704x1344",
57
- "768x1344",
58
- "768x1280",
59
- "832x1216",
60
- "832x1152",
61
- "896x1152",
62
- "896x1088",
63
- "960x1088",
64
- "960x1024",
65
  "1024x1024",
66
- "1024x960",
67
- "1088x960",
68
- "1088x896",
69
- "1152x896",
70
- "1152x832",
71
- "1216x832",
72
- "1280x768",
73
- "1344x768",
74
- "1344x704",
75
- "1408x704",
76
- "1472x704",
77
- "1536x640",
78
- "1600x640",
79
- "1664x576",
80
- "1728x576",
81
  ],
 
82
  TXT2TXT_DEFAULT_SYSTEM="You are a helpful assistant. Be precise and concise.",
83
  TXT2TXT_DEFAULT_MODEL={
84
- "Huggingface": 4,
85
  "Perplexity": 3,
86
  },
87
  TXT2TXT_MODELS={
88
- "Huggingface": [
89
  "codellama/codellama-34b-instruct-hf",
90
  "meta-llama/llama-2-13b-chat-hf",
91
  "meta-llama/meta-llama-3.1-405b-instruct-fp8",
 
5
  ICON="⚡",
6
  LAYOUT="wide",
7
  SERVICES={
8
+ "Hugging Face": "https://api-inference.huggingface.co/models",
9
  "Perplexity": "https://api.perplexity.ai",
10
  "Fal": "https://fal.run",
11
  },
12
  TXT2IMG_HIDDEN_PARAMETERS=[
13
+ # sent to API but not shown in generation parameters accordion
14
  "enable_safety_checker",
15
  "max_sequence_length",
16
  "num_images",
 
23
  "styles",
24
  "sync_mode",
25
  ],
26
+ TXT2IMG_NEGATIVE_PROMPT="ugly, unattractive, disfigured, deformed, mutated, malformed, blurry, grainy, noisy, oversaturated, undersaturated, overexposed, underexposed, worst quality, low details, lowres, watermark, signature, autograph, trademark, sloppy, cluttered",
27
  TXT2IMG_DEFAULT_MODEL={
28
+ # index of model in below lists
29
+ "Hugging Face": 2,
30
  "Fal": 1,
31
  },
32
  TXT2IMG_MODELS={
33
+ "Hugging Face": [
34
  "black-forest-labs/flux.1-dev",
35
  "black-forest-labs/flux.1-schnell",
36
  "stabilityai/stable-diffusion-xl-base-1.0",
37
  ],
38
  "Fal": [
39
+ # TODO: fix these models
40
  # "fal-ai/aura-flow",
41
  # "fal-ai/flux-pro",
42
  "fal-ai/fooocus",
 
55
  ],
56
  TXT2IMG_DEFAULT_ASPECT_RATIO="1024x1024", # fooocus aspect ratios
57
  TXT2IMG_ASPECT_RATIOS=[
58
+ "704x1408", # 1:2
59
+ "704x1344", # 11:21
60
+ "768x1344", # 4:7
61
+ "768x1280", # 3:5
62
+ "832x1216", # 13:19
63
+ "832x1152", # 13:18
64
+ "896x1152", # 7:9
65
+ "896x1088", # 14:17
66
+ "960x1088", # 15:17
67
+ "960x1024", # 15:16
68
  "1024x1024",
69
+ "1024x960", # 16:15
70
+ "1088x960", # 17:15
71
+ "1088x896", # 17:14
72
+ "1152x896", # 9:7
73
+ "1152x832", # 18:13
74
+ "1216x832", # 19:13
75
+ "1280x768", # 5:3
76
+ "1344x768", # 7:4
77
+ "1344x704", # 21:11
78
+ "1408x704", # 2:1
 
 
 
 
 
79
  ],
80
+ # TODO: txt2img fooocus styles like "Fooocus V2" and "Fooocus Enhance" (use multiselect in UI)
81
  TXT2TXT_DEFAULT_SYSTEM="You are a helpful assistant. Be precise and concise.",
82
  TXT2TXT_DEFAULT_MODEL={
83
+ "Hugging Face": 4,
84
  "Perplexity": 3,
85
  },
86
  TXT2TXT_MODELS={
87
+ "Hugging Face": [
88
  "codellama/codellama-34b-instruct-hf",
89
  "meta-llama/llama-2-13b-chat-hf",
90
  "meta-llama/meta-llama-3.1-405b-instruct-fp8",
lib/presets.py CHANGED
@@ -2,14 +2,14 @@ from types import SimpleNamespace
2
 
3
  # txt2txt services
4
  ServicePresets = SimpleNamespace(
5
- Huggingface={
6
  # every service has model and system messages
7
  "frequency_penalty": 0.0,
8
  "frequency_penalty_min": -2.0,
9
  "frequency_penalty_max": 2.0,
10
  "parameters": ["max_tokens", "temperature", "frequency_penalty", "seed"],
11
  },
12
- Perplexity={
13
  "frequency_penalty": 1.0,
14
  "frequency_penalty_min": 1.0,
15
  "frequency_penalty_max": 2.0,
 
2
 
3
  # txt2txt services
4
  ServicePresets = SimpleNamespace(
5
+ HUGGING_FACE={
6
  # every service has model and system messages
7
  "frequency_penalty": 0.0,
8
  "frequency_penalty_min": -2.0,
9
  "frequency_penalty_max": 2.0,
10
  "parameters": ["max_tokens", "temperature", "frequency_penalty", "seed"],
11
  },
12
+ PERPLEXITY={
13
  "frequency_penalty": 1.0,
14
  "frequency_penalty_min": 1.0,
15
  "frequency_penalty_max": 2.0,
pages/1_💬_Text_Generation.py CHANGED
@@ -5,8 +5,15 @@ import streamlit as st
5
 
6
  from lib import Config, ServicePresets, txt2txt_generate
7
 
8
- HF_TOKEN = os.environ.get("HF_TOKEN") or None
9
- PERPLEXITY_API_KEY = os.environ.get("PERPLEXITY_API_KEY") or None
 
 
 
 
 
 
 
10
 
11
  # config
12
  st.set_page_config(
@@ -16,8 +23,8 @@ st.set_page_config(
16
  )
17
 
18
  # initialize state
19
- if "api_key_huggingface" not in st.session_state:
20
- st.session_state.api_key_huggingface = ""
21
 
22
  if "api_key_perplexity" not in st.session_state:
23
  st.session_state.api_key_perplexity = ""
@@ -28,53 +35,36 @@ if "running" not in st.session_state:
28
  if "txt2txt_messages" not in st.session_state:
29
  st.session_state.txt2txt_messages = []
30
 
31
- if "txt2txt_prompt" not in st.session_state:
32
- st.session_state.txt2txt_prompt = ""
33
 
34
  # sidebar
35
  st.logo("logo.svg")
36
  st.sidebar.header("Settings")
37
  service = st.sidebar.selectbox(
38
  "Service",
39
- options=["Huggingface", "Perplexity"],
40
  index=0,
41
  disabled=st.session_state.running,
42
  )
43
 
44
- if service == "Huggingface" and HF_TOKEN is None:
45
- st.session_state.api_key_huggingface = st.sidebar.text_input(
46
- "API Key",
47
- type="password",
48
- help="Cleared on page refresh",
49
- disabled=st.session_state.running,
50
- value=st.session_state.api_key_huggingface,
51
- )
52
- else:
53
- st.session_state.api_key_huggingface = st.session_state.api_key_huggingface
54
-
55
- if service == "Perplexity" and PERPLEXITY_API_KEY is None:
56
- st.session_state.api_key_perplexity = st.sidebar.text_input(
57
- "API Key",
58
- type="password",
59
- help="Cleared on page refresh",
60
- disabled=st.session_state.running,
61
- value=st.session_state.api_key_perplexity,
62
- )
63
- else:
64
- st.session_state.api_key_perplexity = st.session_state.api_key_perplexity
65
-
66
- if service == "Huggingface" and HF_TOKEN is not None:
67
- st.session_state.api_key_huggingface = HF_TOKEN
68
-
69
- if service == "Perplexity" and PERPLEXITY_API_KEY is not None:
70
- st.session_state.api_key_perplexity = PERPLEXITY_API_KEY
71
 
72
  model = st.sidebar.selectbox(
73
  "Model",
74
  options=Config.TXT2TXT_MODELS[service],
75
  index=Config.TXT2TXT_DEFAULT_MODEL[service],
76
  disabled=st.session_state.running,
77
- format_func=lambda x: x.split("/")[1] if service == "Huggingface" else x,
78
  )
79
  system = st.sidebar.text_area(
80
  "System Message",
@@ -84,7 +74,8 @@ system = st.sidebar.text_area(
84
 
85
  # build parameters from preset
86
  parameters = {}
87
- preset = getattr(ServicePresets, service, {})
 
88
  for param in preset["parameters"]:
89
  if param == "max_tokens":
90
  parameters[param] = st.sidebar.slider(
@@ -152,26 +143,15 @@ if st.session_state.txt2txt_messages:
152
  </style>
153
  """)
154
 
155
- # remove last assistant message and resend prompt
156
- col1, col2, col3 = st.columns(3)
157
  with col1:
158
- if st.button("🔄️", help="Retry last message") and len(st.session_state.txt2txt_messages) >= 2:
159
- st.session_state.txt2txt_messages.pop()
160
- st.session_state.txt2txt_prompt = st.session_state.txt2txt_messages.pop()["content"]
161
- st.rerun()
162
-
163
- # delete last message pair
164
- with col2:
165
  if st.button("❌", help="Delete last message") and len(st.session_state.txt2txt_messages) >= 2:
166
  st.session_state.txt2txt_messages.pop()
167
  st.session_state.txt2txt_messages.pop()
168
  st.rerun()
169
-
170
- # reset app state
171
- with col3:
172
  if st.button("🗑️", help="Clear all messages"):
173
  st.session_state.txt2txt_messages = []
174
- st.session_state.txt2txt_prompt = ""
175
  st.rerun()
176
  else:
177
  button_container = None
@@ -181,28 +161,30 @@ if prompt := st.chat_input(
181
  "What would you like to know?",
182
  on_submit=lambda: setattr(st.session_state, "running", True),
183
  ):
184
- st.session_state.txt2txt_prompt = prompt
185
-
186
- if parameters.get("seed", 0) < 0:
187
- parameters["seed"] = int(datetime.now().timestamp() * 1e6) % (1 << 53)
 
 
188
 
189
  if button_container:
190
  button_container.empty()
191
 
192
  messages = [{"role": "system", "content": system}]
193
  messages.extend([{"role": m["role"], "content": m["content"]} for m in st.session_state.txt2txt_messages])
194
- messages.append({"role": "user", "content": st.session_state.txt2txt_prompt})
195
  parameters["messages"] = messages
196
 
197
  with st.chat_message("user"):
198
- st.markdown(st.session_state.txt2txt_prompt)
199
 
200
  with st.chat_message("assistant"):
201
- api_key = getattr(st.session_state, f"api_key_{service.lower()}", None)
 
202
  response = txt2txt_generate(api_key, service, model, parameters)
203
  st.session_state.running = False
204
 
205
- st.session_state.txt2txt_messages.append({"role": "user", "content": st.session_state.txt2txt_prompt})
206
  st.session_state.txt2txt_messages.append({"role": "assistant", "content": response})
207
- st.session_state.txt2txt_prompt = ""
208
  st.rerun()
 
5
 
6
  from lib import Config, ServicePresets, txt2txt_generate
7
 
8
+ SERVICE_SESSION = {
9
+ "Hugging Face": "api_key_hugging_face",
10
+ "Perplexity": "api_key_perplexity",
11
+ }
12
+
13
+ SESSION_TOKEN = {
14
+ "api_key_hugging_face": os.environ.get("HF_TOKEN") or None,
15
+ "api_key_perplexity": os.environ.get("PERPLEXITY_API_KEY") or None,
16
+ }
17
 
18
  # config
19
  st.set_page_config(
 
23
  )
24
 
25
  # initialize state
26
+ if "api_key_hugging_face" not in st.session_state:
27
+ st.session_state.api_key_hugging_face = ""
28
 
29
  if "api_key_perplexity" not in st.session_state:
30
  st.session_state.api_key_perplexity = ""
 
35
  if "txt2txt_messages" not in st.session_state:
36
  st.session_state.txt2txt_messages = []
37
 
38
+ if "txt2txt_seed" not in st.session_state:
39
+ st.session_state.txt2txt_seed = 0
40
 
41
  # sidebar
42
  st.logo("logo.svg")
43
  st.sidebar.header("Settings")
44
  service = st.sidebar.selectbox(
45
  "Service",
46
+ options=["Hugging Face", "Perplexity"],
47
  index=0,
48
  disabled=st.session_state.running,
49
  )
50
 
51
+ # disable API key input and hide value if set by environment variable (handle empty string value later)
52
+ for display_name, session_key in SERVICE_SESSION.items():
53
+ if service == display_name:
54
+ st.session_state[session_key] = st.sidebar.text_input(
55
+ "API Key",
56
+ type="password",
57
+ value="" if SESSION_TOKEN[session_key] else st.session_state[session_key],
58
+ disabled=bool(SESSION_TOKEN[session_key]) or st.session_state.running,
59
+ help="Set by environment variable" if SESSION_TOKEN[session_key] else "Cleared on page refresh",
60
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
  model = st.sidebar.selectbox(
63
  "Model",
64
  options=Config.TXT2TXT_MODELS[service],
65
  index=Config.TXT2TXT_DEFAULT_MODEL[service],
66
  disabled=st.session_state.running,
67
+ format_func=lambda x: x.split("/")[1] if service == "Hugging Face" else x,
68
  )
69
  system = st.sidebar.text_area(
70
  "System Message",
 
74
 
75
  # build parameters from preset
76
  parameters = {}
77
+ service_key = service.upper().replace(" ", "_")
78
+ preset = getattr(ServicePresets, service_key, {})
79
  for param in preset["parameters"]:
80
  if param == "max_tokens":
81
  parameters[param] = st.sidebar.slider(
 
143
  </style>
144
  """)
145
 
146
+ col1, col2 = st.columns(2)
 
147
  with col1:
 
 
 
 
 
 
 
148
  if st.button("❌", help="Delete last message") and len(st.session_state.txt2txt_messages) >= 2:
149
  st.session_state.txt2txt_messages.pop()
150
  st.session_state.txt2txt_messages.pop()
151
  st.rerun()
152
+ with col2:
 
 
153
  if st.button("🗑️", help="Clear all messages"):
154
  st.session_state.txt2txt_messages = []
 
155
  st.rerun()
156
  else:
157
  button_container = None
 
161
  "What would you like to know?",
162
  on_submit=lambda: setattr(st.session_state, "running", True),
163
  ):
164
+ if "seed" in parameters and parameters["seed"] >= 0:
165
+ st.session_state.txt2txt_seed = parameters["seed"]
166
+ else:
167
+ st.session_state.txt2txt_seed = int(datetime.now().timestamp() * 1e6) % (1 << 53)
168
+ if "seed" in parameters:
169
+ parameters["seed"] = st.session_state.txt2txt_seed
170
 
171
  if button_container:
172
  button_container.empty()
173
 
174
  messages = [{"role": "system", "content": system}]
175
  messages.extend([{"role": m["role"], "content": m["content"]} for m in st.session_state.txt2txt_messages])
176
+ messages.append({"role": "user", "content": prompt})
177
  parameters["messages"] = messages
178
 
179
  with st.chat_message("user"):
180
+ st.markdown(prompt)
181
 
182
  with st.chat_message("assistant"):
183
+ session_key = f"api_key_{service.lower().replace(' ', '_')}"
184
+ api_key = st.session_state[session_key] or SESSION_TOKEN[session_key]
185
  response = txt2txt_generate(api_key, service, model, parameters)
186
  st.session_state.running = False
187
 
188
+ st.session_state.txt2txt_messages.append({"role": "user", "content": prompt})
189
  st.session_state.txt2txt_messages.append({"role": "assistant", "content": response})
 
190
  st.rerun()
pages/2_🎨_Text_to_Image.py CHANGED
@@ -5,8 +5,16 @@ import streamlit as st
5
 
6
  from lib import Config, ModelPresets, txt2img_generate
7
 
8
- HF_TOKEN = os.environ.get("HF_TOKEN") or None
9
- FAL_KEY = os.environ.get("FAL_KEY") or None
 
 
 
 
 
 
 
 
10
  PRESET_MODEL = {
11
  "black-forest-labs/flux.1-dev": ModelPresets.FLUX_DEV,
12
  "black-forest-labs/flux.1-schnell": ModelPresets.FLUX_SCHNELL,
@@ -28,15 +36,15 @@ st.set_page_config(
28
  if "api_key_fal" not in st.session_state:
29
  st.session_state.api_key_fal = ""
30
 
31
- if "api_key_huggingface" not in st.session_state:
32
- st.session_state.api_key_huggingface = ""
33
-
34
- if "txt2img_messages" not in st.session_state:
35
- st.session_state.txt2img_messages = []
36
 
37
  if "running" not in st.session_state:
38
  st.session_state.running = False
39
 
 
 
 
40
  if "txt2img_seed" not in st.session_state:
41
  st.session_state.txt2img_seed = 0
42
 
@@ -45,38 +53,21 @@ st.logo("logo.svg")
45
  st.sidebar.header("Settings")
46
  service = st.sidebar.selectbox(
47
  "Service",
48
- options=["Fal", "Huggingface"],
49
  index=1,
50
  disabled=st.session_state.running,
51
  )
52
 
53
- if service == "Huggingface" and HF_TOKEN is None:
54
- st.session_state.api_key_huggingface = st.sidebar.text_input(
55
- "API Key",
56
- type="password",
57
- help="Cleared on page refresh",
58
- value=st.session_state.api_key_huggingface,
59
- disabled=st.session_state.running,
60
- )
61
- else:
62
- st.session_state.api_key_huggingface = st.session_state.api_key_huggingface
63
-
64
- if service == "Fal" and FAL_KEY is None:
65
- st.session_state.api_key_fal = st.sidebar.text_input(
66
- "API Key",
67
- type="password",
68
- help="Cleared on page refresh",
69
- value=st.session_state.api_key_fal,
70
- disabled=st.session_state.running,
71
- )
72
- else:
73
- st.session_state.api_key_fal = st.session_state.api_key_fal
74
-
75
- if service == "Huggingface" and HF_TOKEN is not None:
76
- st.session_state.api_key_huggingface = HF_TOKEN
77
-
78
- if service == "Fal" and FAL_KEY is not None:
79
- st.session_state.api_key_fal = FAL_KEY
80
 
81
  model = st.sidebar.selectbox(
82
  "Model",
@@ -227,7 +218,6 @@ if st.session_state.txt2img_messages:
227
  </style>
228
  """)
229
 
230
- # retry
231
  col1, col2 = st.columns(2)
232
  with col1:
233
  if (
@@ -268,7 +258,8 @@ if prompt := st.chat_input(
268
  with st.spinner("Running..."):
269
  if preset.get("kwargs") is not None:
270
  parameters.update(preset["kwargs"])
271
- api_key = getattr(st.session_state, f"api_key_{service.lower()}", None)
 
272
  image = txt2img_generate(api_key, service, model, prompt, parameters)
273
  st.session_state.running = False
274
 
 
5
 
6
  from lib import Config, ModelPresets, txt2img_generate
7
 
8
+ SERVICE_SESSION = {
9
+ "Fal": "api_key_fal",
10
+ "Hugging Face": "api_key_hugging_face",
11
+ }
12
+
13
+ SESSION_TOKEN = {
14
+ "api_key_fal": os.environ.get("FAL_KEY") or None,
15
+ "api_key_hugging_face": os.environ.get("HF_TOKEN") or None,
16
+ }
17
+
18
  PRESET_MODEL = {
19
  "black-forest-labs/flux.1-dev": ModelPresets.FLUX_DEV,
20
  "black-forest-labs/flux.1-schnell": ModelPresets.FLUX_SCHNELL,
 
36
  if "api_key_fal" not in st.session_state:
37
  st.session_state.api_key_fal = ""
38
 
39
+ if "api_key_hugging_face" not in st.session_state:
40
+ st.session_state.api_key_hugging_face = ""
 
 
 
41
 
42
  if "running" not in st.session_state:
43
  st.session_state.running = False
44
 
45
+ if "txt2img_messages" not in st.session_state:
46
+ st.session_state.txt2img_messages = []
47
+
48
  if "txt2img_seed" not in st.session_state:
49
  st.session_state.txt2img_seed = 0
50
 
 
53
  st.sidebar.header("Settings")
54
  service = st.sidebar.selectbox(
55
  "Service",
56
+ options=list(SERVICE_SESSION.keys()),
57
  index=1,
58
  disabled=st.session_state.running,
59
  )
60
 
61
+ # disable API key input and hide value if set by environment variable (handle empty string value later)
62
+ for display_name, session_key in SERVICE_SESSION.items():
63
+ if service == display_name:
64
+ st.session_state[session_key] = st.sidebar.text_input(
65
+ "API Key",
66
+ type="password",
67
+ value="" if SESSION_TOKEN[session_key] else st.session_state[session_key],
68
+ disabled=bool(SESSION_TOKEN[session_key]) or st.session_state.running,
69
+ help="Set by environment variable" if SESSION_TOKEN[session_key] else "Cleared on page refresh",
70
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
 
72
  model = st.sidebar.selectbox(
73
  "Model",
 
218
  </style>
219
  """)
220
 
 
221
  col1, col2 = st.columns(2)
222
  with col1:
223
  if (
 
258
  with st.spinner("Running..."):
259
  if preset.get("kwargs") is not None:
260
  parameters.update(preset["kwargs"])
261
+ session_key = f"api_key_{service.lower().replace(' ', '_')}"
262
+ api_key = st.session_state[session_key] or SESSION_TOKEN[session_key]
263
  image = txt2img_generate(api_key, service, model, prompt, parameters)
264
  st.session_state.running = False
265