adamelliotfields committed
Commit a5515e4
1 Parent(s): 297482a

Better environment variable handling

Files changed:
- lib/api.py +6 -5
- lib/config.py +30 -31
- lib/presets.py +2 -2
- pages/1_💬_Text_Generation.py +40 -58
- pages/2_🎨_Text_to_Image.py +28 -37
lib/api.py
CHANGED
@@ -11,7 +11,7 @@ from .config import Config

def txt2txt_generate(api_key, service, model, parameters, **kwargs):
    base_url = Config.SERVICES[service]
-    if service == "Huggingface":
+    if service == "Hugging Face":
        base_url = f"{base_url}/{model}/v1"
    client = OpenAI(api_key=api_key, base_url=base_url)

@@ -19,14 +19,15 @@ def txt2txt_generate(api_key, service, model, parameters, **kwargs):
        stream = client.chat.completions.create(stream=True, model=model, **parameters, **kwargs)
        return st.write_stream(stream)
    except APIError as e:
-        …
+        # OpenAI uses this message for streaming errors and attaches response.error to error.body
+        return e.body if e.message == "An error occurred during streaming" else e.message
    except Exception as e:
        return str(e)


def txt2img_generate(api_key, service, model, inputs, parameters, **kwargs):
    headers = {}
-    if service == "Huggingface":
+    if service == "Hugging Face":
        headers["Authorization"] = f"Bearer {api_key}"
        headers["X-Wait-For-Model"] = "true"
        headers["X-Use-Cache"] = "false"
@@ -34,7 +35,7 @@ def txt2img_generate(api_key, service, model, inputs, parameters, **kwargs):
        headers["Authorization"] = f"Key {api_key}"

    json = {}
-    if service == "Huggingface":
+    if service == "Hugging Face":
        json = {
            "inputs": inputs,
            "parameters": {**parameters, **kwargs},
@@ -48,7 +49,7 @@ def txt2img_generate(api_key, service, model, inputs, parameters, **kwargs):
    try:
        response = requests.post(base_url, headers=headers, json=json)
        if response.status_code // 100 == 2: # 2xx
-            if service == "Huggingface":
+            if service == "Hugging Face":
                return Image.open(io.BytesIO(response.content))
            if service == "Fal":
                bytes = base64.b64decode(response.json()["images"][0]["url"].split(",")[-1])
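The hunks above only special-case the "Hugging Face" service when building the chat client. A minimal illustrative sketch, not part of the commit, of how that base-URL branch plays out; make_client is a hypothetical helper and the standalone SERVICES dict stands in for Config.SERVICES from lib/config.py:

from openai import OpenAI

SERVICES = {
    "Hugging Face": "https://api-inference.huggingface.co/models",
    "Perplexity": "https://api.perplexity.ai",
}


def make_client(api_key: str, service: str, model: str) -> OpenAI:
    # Hugging Face serves each model under an OpenAI-compatible path, so the
    # model id and /v1 are appended; Perplexity is used at its root URL
    base_url = SERVICES[service]
    if service == "Hugging Face":
        base_url = f"{base_url}/{model}/v1"
    return OpenAI(api_key=api_key, base_url=base_url)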
lib/config.py
CHANGED
@@ -5,11 +5,12 @@ Config = SimpleNamespace(
    ICON="⚡",
    LAYOUT="wide",
    SERVICES={
-        "Huggingface": "https://api-inference.huggingface.co/models",
+        "Hugging Face": "https://api-inference.huggingface.co/models",
        "Perplexity": "https://api.perplexity.ai",
        "Fal": "https://fal.run",
    },
    TXT2IMG_HIDDEN_PARAMETERS=[
+        # sent to API but not shown in generation parameters accordion
        "enable_safety_checker",
        "max_sequence_length",
        "num_images",
@@ -22,18 +23,20 @@ Config = SimpleNamespace(
        "styles",
        "sync_mode",
    ],
-    TXT2IMG_NEGATIVE_PROMPT="ugly, unattractive, …
+    TXT2IMG_NEGATIVE_PROMPT="ugly, unattractive, disfigured, deformed, mutated, malformed, blurry, grainy, noisy, oversaturated, undersaturated, overexposed, underexposed, worst quality, low details, lowres, watermark, signature, autograph, trademark, sloppy, cluttered",
    TXT2IMG_DEFAULT_MODEL={
-        "Huggingface": …
+        # index of model in below lists
+        "Hugging Face": 2,
        "Fal": 1,
    },
    TXT2IMG_MODELS={
-        "Huggingface": [
+        "Hugging Face": [
            "black-forest-labs/flux.1-dev",
            "black-forest-labs/flux.1-schnell",
            "stabilityai/stable-diffusion-xl-base-1.0",
        ],
        "Fal": [
+            # TODO: fix these models
            # "fal-ai/aura-flow",
            # "fal-ai/flux-pro",
            "fal-ai/fooocus",
@@ -52,40 +55,36 @@ Config = SimpleNamespace(
    ],
    TXT2IMG_DEFAULT_ASPECT_RATIO="1024x1024", # fooocus aspect ratios
    TXT2IMG_ASPECT_RATIOS=[
-        "704x1408",
-        "704x1344",
-        "768x1344",
-        "768x1280",
-        "832x1216",
-        "832x1152",
-        "896x1152",
-        "896x1088",
-        "960x1088",
-        "960x1024",
+        "704x1408", # 1:2
+        "704x1344", # 11:21
+        "768x1344", # 4:7
+        "768x1280", # 3:5
+        "832x1216", # 13:19
+        "832x1152", # 13:18
+        "896x1152", # 7:9
+        "896x1088", # 14:17
+        "960x1088", # 15:17
+        "960x1024", # 15:16
        "1024x1024",
-        "1024x960",
-        "1088x960",
-        "1088x896",
-        "1152x896",
-        "1152x832",
-        "1216x832",
-        "1280x768",
-        "1344x768",
-        "1344x704",
-        "1408x704",
-        "1472x704",
-        "1536x640",
-        "1600x640",
-        "1664x576",
-        "1728x576",
+        "1024x960", # 16:15
+        "1088x960", # 17:15
+        "1088x896", # 17:14
+        "1152x896", # 9:7
+        "1152x832", # 18:13
+        "1216x832", # 19:13
+        "1280x768", # 5:3
+        "1344x768", # 7:4
+        "1344x704", # 21:11
+        "1408x704", # 2:1
    ],
+    # TODO: txt2img fooocus styles like "Fooocus V2" and "Fooocus Enhance" (use multiselect in UI)
    TXT2TXT_DEFAULT_SYSTEM="You are a helpful assistant. Be precise and concise.",
    TXT2TXT_DEFAULT_MODEL={
-        "Huggingface": …
+        "Hugging Face": 4,
        "Perplexity": 3,
    },
    TXT2TXT_MODELS={
-        "Huggingface": [
+        "Hugging Face": [
            "codellama/codellama-34b-instruct-hf",
            "meta-llama/llama-2-13b-chat-hf",
            "meta-llama/meta-llama-3.1-405b-instruct-fp8",
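An illustrative check, not part of the commit, of how the "index of model in below lists" defaults resolve against the renamed keys; the values are copied from the diff above and Config is rebuilt here only with the two relevant entries:

from types import SimpleNamespace

Config = SimpleNamespace(
    TXT2IMG_DEFAULT_MODEL={"Hugging Face": 2, "Fal": 1},
    TXT2IMG_MODELS={
        "Hugging Face": [
            "black-forest-labs/flux.1-dev",
            "black-forest-labs/flux.1-schnell",
            "stabilityai/stable-diffusion-xl-base-1.0",
        ],
    },
)

service = "Hugging Face"
# the default index 2 selects the third Hugging Face model
default_model = Config.TXT2IMG_MODELS[service][Config.TXT2IMG_DEFAULT_MODEL[service]]
assert default_model == "stabilityai/stable-diffusion-xl-base-1.0"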
lib/presets.py
CHANGED
@@ -2,14 +2,14 @@ from types import SimpleNamespace

# txt2txt services
ServicePresets = SimpleNamespace(
-    …
+    HUGGING_FACE={
        # every service has model and system messages
        "frequency_penalty": 0.0,
        "frequency_penalty_min": -2.0,
        "frequency_penalty_max": 2.0,
        "parameters": ["max_tokens", "temperature", "frequency_penalty", "seed"],
    },
-    …
+    PERPLEXITY={
        "frequency_penalty": 1.0,
        "frequency_penalty_min": 1.0,
        "frequency_penalty_max": 2.0,
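An illustrative check, not part of the commit, that the renamed preset attributes line up with the lookup added in pages/1_💬_Text_Generation.py; the preset bodies are trimmed to a single key here:

from types import SimpleNamespace

ServicePresets = SimpleNamespace(
    HUGGING_FACE={"frequency_penalty": 0.0},
    PERPLEXITY={"frequency_penalty": 1.0},
)

service = "Hugging Face"
service_key = service.upper().replace(" ", "_")  # "HUGGING_FACE"
preset = getattr(ServicePresets, service_key, {})
assert preset["frequency_penalty"] == 0.0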
pages/1_💬_Text_Generation.py
CHANGED
@@ -5,8 +5,15 @@ import streamlit as st

from lib import Config, ServicePresets, txt2txt_generate

-HF_TOKEN = …
-PERPLEXITY_API_KEY = …
+SERVICE_SESSION = {
+    "Hugging Face": "api_key_hugging_face",
+    "Perplexity": "api_key_perplexity",
+}
+
+SESSION_TOKEN = {
+    "api_key_hugging_face": os.environ.get("HF_TOKEN") or None,
+    "api_key_perplexity": os.environ.get("PERPLEXITY_API_KEY") or None,
+}

# config
st.set_page_config(
@@ -16,8 +23,8 @@ st.set_page_config(
)

# initialize state
-if "api_key_huggingface" not in st.session_state:
-    st.session_state.api_key_huggingface = ""
+if "api_key_hugging_face" not in st.session_state:
+    st.session_state.api_key_hugging_face = ""

if "api_key_perplexity" not in st.session_state:
    st.session_state.api_key_perplexity = ""
@@ -28,53 +35,36 @@ if "running" not in st.session_state:
if "txt2txt_messages" not in st.session_state:
    st.session_state.txt2txt_messages = []

-if "txt2txt_prompt" not in st.session_state:
-    st.session_state.txt2txt_prompt = ""
+if "txt2txt_seed" not in st.session_state:
+    st.session_state.txt2txt_seed = 0

# sidebar
st.logo("logo.svg")
st.sidebar.header("Settings")
service = st.sidebar.selectbox(
    "Service",
-    options=["Huggingface", "Perplexity"],
+    options=["Hugging Face", "Perplexity"],
    index=0,
    disabled=st.session_state.running,
)

-if service == "Huggingface" and HF_TOKEN is None:
-    st.session_state.api_key_huggingface = st.sidebar.text_input(
-        …
-        …
-        …
-        …
-        …
-    )
-else:
-    st.session_state.api_key_huggingface = st.session_state.api_key_huggingface
-
-if service == "Perplexity" and PERPLEXITY_API_KEY is None:
-    st.session_state.api_key_perplexity = st.sidebar.text_input(
-        "API Key",
-        type="password",
-        help="Cleared on page refresh",
-        disabled=st.session_state.running,
-        value=st.session_state.api_key_perplexity,
-    )
-else:
-    st.session_state.api_key_perplexity = st.session_state.api_key_perplexity
-
-if service == "Huggingface" and HF_TOKEN is not None:
-    st.session_state.api_key_huggingface = HF_TOKEN
-
-if service == "Perplexity" and PERPLEXITY_API_KEY is not None:
-    st.session_state.api_key_perplexity = PERPLEXITY_API_KEY
+# disable API key input and hide value if set by environment variable (handle empty string value later)
+for display_name, session_key in SERVICE_SESSION.items():
+    if service == display_name:
+        st.session_state[session_key] = st.sidebar.text_input(
+            "API Key",
+            type="password",
+            value="" if SESSION_TOKEN[session_key] else st.session_state[session_key],
+            disabled=bool(SESSION_TOKEN[session_key]) or st.session_state.running,
+            help="Set by environment variable" if SESSION_TOKEN[session_key] else "Cleared on page refresh",
+        )

model = st.sidebar.selectbox(
    "Model",
    options=Config.TXT2TXT_MODELS[service],
    index=Config.TXT2TXT_DEFAULT_MODEL[service],
    disabled=st.session_state.running,
-    format_func=lambda x: x.split("/")[1] if service == "Huggingface" else x,
+    format_func=lambda x: x.split("/")[1] if service == "Hugging Face" else x,
)
system = st.sidebar.text_area(
    "System Message",
@@ -84,7 +74,8 @@ system = st.sidebar.text_area(

# build parameters from preset
parameters = {}
-preset = …
+service_key = service.upper().replace(" ", "_")
+preset = getattr(ServicePresets, service_key, {})
for param in preset["parameters"]:
    if param == "max_tokens":
        parameters[param] = st.sidebar.slider(
@@ -152,26 +143,15 @@ if st.session_state.txt2txt_messages:
    </style>
    """)

-    …
-    col1, col2, col3 = st.columns(3)
+    col1, col2 = st.columns(2)
    with col1:
-        if st.button("🔄️", help="Retry last message") and len(st.session_state.txt2txt_messages) >= 2:
-            st.session_state.txt2txt_messages.pop()
-            st.session_state.txt2txt_prompt = st.session_state.txt2txt_messages.pop()["content"]
-            st.rerun()
-
-    # delete last message pair
-    with col2:
        if st.button("❌", help="Delete last message") and len(st.session_state.txt2txt_messages) >= 2:
            st.session_state.txt2txt_messages.pop()
            st.session_state.txt2txt_messages.pop()
            st.rerun()
-
-    # reset app state
-    with col3:
+    with col2:
        if st.button("🗑️", help="Clear all messages"):
            st.session_state.txt2txt_messages = []
-            st.session_state.txt2txt_prompt = ""
            st.rerun()
else:
    button_container = None
@@ -181,28 +161,30 @@ if prompt := st.chat_input(
    "What would you like to know?",
    on_submit=lambda: setattr(st.session_state, "running", True),
):
-    …
-    …
-    …
-    …
+    if "seed" in parameters and parameters["seed"] >= 0:
+        st.session_state.txt2txt_seed = parameters["seed"]
+    else:
+        st.session_state.txt2txt_seed = int(datetime.now().timestamp() * 1e6) % (1 << 53)
+    if "seed" in parameters:
+        parameters["seed"] = st.session_state.txt2txt_seed

    if button_container:
        button_container.empty()

    messages = [{"role": "system", "content": system}]
    messages.extend([{"role": m["role"], "content": m["content"]} for m in st.session_state.txt2txt_messages])
-    messages.append({"role": "user", "content": …
+    messages.append({"role": "user", "content": prompt})
    parameters["messages"] = messages

    with st.chat_message("user"):
-        st.markdown(…
+        st.markdown(prompt)

    with st.chat_message("assistant"):
-        …
+        session_key = f"api_key_{service.lower().replace(' ', '_')}"
+        api_key = st.session_state[session_key] or SESSION_TOKEN[session_key]
        response = txt2txt_generate(api_key, service, model, parameters)
        st.session_state.running = False

-    st.session_state.txt2txt_messages.append({"role": "user", "content": …
+    st.session_state.txt2txt_messages.append({"role": "user", "content": prompt})
    st.session_state.txt2txt_messages.append({"role": "assistant", "content": response})
-    st.session_state.txt2txt_prompt = ""
    st.rerun()
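An illustrative sketch, not part of the commit, of the API-key resolution this page now uses: an environment variable takes precedence and disables the sidebar input, otherwise the value typed into the sidebar is used. A plain dict stands in for st.session_state and resolve_api_key is a hypothetical helper:

import os

SERVICE_SESSION = {
    "Hugging Face": "api_key_hugging_face",
    "Perplexity": "api_key_perplexity",
}
SESSION_TOKEN = {
    "api_key_hugging_face": os.environ.get("HF_TOKEN") or None,
    "api_key_perplexity": os.environ.get("PERPLEXITY_API_KEY") or None,
}
session_state = {key: "" for key in SESSION_TOKEN}  # stand-in for st.session_state


def resolve_api_key(service: str) -> str | None:
    session_key = SERVICE_SESSION[service]
    # user-entered key if present, else the environment variable (or None)
    return session_state[session_key] or SESSION_TOKEN[session_key]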
pages/2_🎨_Text_to_Image.py
CHANGED
@@ -5,8 +5,16 @@ import streamlit as st

from lib import Config, ModelPresets, txt2img_generate

-FAL_KEY = …
-HF_TOKEN = …
+SERVICE_SESSION = {
+    "Fal": "api_key_fal",
+    "Hugging Face": "api_key_hugging_face",
+}
+
+SESSION_TOKEN = {
+    "api_key_fal": os.environ.get("FAL_KEY") or None,
+    "api_key_hugging_face": os.environ.get("HF_TOKEN") or None,
+}
+
PRESET_MODEL = {
    "black-forest-labs/flux.1-dev": ModelPresets.FLUX_DEV,
    "black-forest-labs/flux.1-schnell": ModelPresets.FLUX_SCHNELL,
@@ -28,15 +36,15 @@ st.set_page_config(
if "api_key_fal" not in st.session_state:
    st.session_state.api_key_fal = ""

-if "api_key_huggingface" not in st.session_state:
-    st.session_state.api_key_huggingface = ""
-
-if "txt2img_messages" not in st.session_state:
-    st.session_state.txt2img_messages = []
+if "api_key_hugging_face" not in st.session_state:
+    st.session_state.api_key_hugging_face = ""

if "running" not in st.session_state:
    st.session_state.running = False

+if "txt2img_messages" not in st.session_state:
+    st.session_state.txt2img_messages = []
+
if "txt2img_seed" not in st.session_state:
    st.session_state.txt2img_seed = 0

@@ -45,38 +53,21 @@ st.logo("logo.svg")
st.sidebar.header("Settings")
service = st.sidebar.selectbox(
    "Service",
-    options=…
+    options=list(SERVICE_SESSION.keys()),
    index=1,
    disabled=st.session_state.running,
)

-if service == "Huggingface" and HF_TOKEN is None:
-    st.session_state.api_key_huggingface = st.sidebar.text_input(
-        …
-        …
-        …
-        …
-        …
-    )
-else:
-    st.session_state.api_key_huggingface = st.session_state.api_key_huggingface
-
-if service == "Fal" and FAL_KEY is None:
-    st.session_state.api_key_fal = st.sidebar.text_input(
-        "API Key",
-        type="password",
-        help="Cleared on page refresh",
-        value=st.session_state.api_key_fal,
-        disabled=st.session_state.running,
-    )
-else:
-    st.session_state.api_key_fal = st.session_state.api_key_fal
-
-if service == "Huggingface" and HF_TOKEN is not None:
-    st.session_state.api_key_huggingface = HF_TOKEN
-
-if service == "Fal" and FAL_KEY is not None:
-    st.session_state.api_key_fal = FAL_KEY
+# disable API key input and hide value if set by environment variable (handle empty string value later)
+for display_name, session_key in SERVICE_SESSION.items():
+    if service == display_name:
+        st.session_state[session_key] = st.sidebar.text_input(
+            "API Key",
+            type="password",
+            value="" if SESSION_TOKEN[session_key] else st.session_state[session_key],
+            disabled=bool(SESSION_TOKEN[session_key]) or st.session_state.running,
+            help="Set by environment variable" if SESSION_TOKEN[session_key] else "Cleared on page refresh",
+        )

model = st.sidebar.selectbox(
    "Model",
@@ -227,7 +218,6 @@ if st.session_state.txt2img_messages:
    </style>
    """)

-    # retry
    col1, col2 = st.columns(2)
    with col1:
        if (
@@ -268,7 +258,8 @@ if prompt := st.chat_input(
    with st.spinner("Running..."):
        if preset.get("kwargs") is not None:
            parameters.update(preset["kwargs"])
-        api_key = …
+        session_key = f"api_key_{service.lower().replace(' ', '_')}"
+        api_key = st.session_state[session_key] or SESSION_TOKEN[session_key]
        image = txt2img_generate(api_key, service, model, prompt, parameters)
        st.session_state.running = False

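A small illustrative check, not part of the commit, that the f-string used right before txt2img_generate derives exactly the session keys that SERVICE_SESSION maps to:

SERVICE_SESSION = {
    "Fal": "api_key_fal",
    "Hugging Face": "api_key_hugging_face",
}

for service, session_key in SERVICE_SESSION.items():
    derived = f"api_key_{service.lower().replace(' ', '_')}"
    assert derived == session_key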