adamelliotfields
commited on
Add BFL API
Browse files- lib/api.py +39 -1
- lib/config.py +19 -9
- lib/presets.py +71 -15
- pages/2_🎨_Text_to_Image.py +28 -18
lib/api.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
import base64
|
2 |
import io
|
|
|
3 |
|
4 |
import httpx
|
5 |
import streamlit as st
|
@@ -20,6 +21,7 @@ def txt2txt_generate(api_key, service, model, parameters, **kwargs):
|
|
20 |
return st.write_stream(stream)
|
21 |
except APIError as e:
|
22 |
# OpenAI uses this message for streaming errors and attaches response.error to error.body
|
|
|
23 |
return e.body if e.message == "An error occurred during streaming" else e.message
|
24 |
except Exception as e:
|
25 |
return str(e)
|
@@ -31,19 +33,28 @@ def txt2img_generate(api_key, service, model, inputs, parameters, **kwargs):
|
|
31 |
headers["Authorization"] = f"Bearer {api_key}"
|
32 |
headers["X-Wait-For-Model"] = "true"
|
33 |
headers["X-Use-Cache"] = "false"
|
|
|
34 |
if service == "Fal":
|
35 |
headers["Authorization"] = f"Key {api_key}"
|
36 |
|
|
|
|
|
|
|
37 |
json = {}
|
38 |
if service == "Hugging Face":
|
39 |
json = {
|
40 |
"inputs": inputs,
|
41 |
"parameters": {**parameters, **kwargs},
|
42 |
}
|
|
|
43 |
if service == "Fal":
|
44 |
json = {**parameters, **kwargs}
|
45 |
json["prompt"] = inputs
|
46 |
|
|
|
|
|
|
|
|
|
47 |
base_url = f"{Config.SERVICES[service]}/{model}"
|
48 |
|
49 |
try:
|
@@ -51,8 +62,9 @@ def txt2img_generate(api_key, service, model, inputs, parameters, **kwargs):
|
|
51 |
if response.status_code // 100 == 2: # 2xx
|
52 |
if service == "Hugging Face":
|
53 |
return Image.open(io.BytesIO(response.content))
|
|
|
54 |
if service == "Fal":
|
55 |
-
#
|
56 |
if parameters.get("sync_mode", True):
|
57 |
bytes = base64.b64decode(response.json()["images"][0]["url"].split(",")[-1])
|
58 |
return Image.open(io.BytesIO(bytes))
|
@@ -60,6 +72,32 @@ def txt2img_generate(api_key, service, model, inputs, parameters, **kwargs):
|
|
60 |
url = response.json()["images"][0]["url"]
|
61 |
image = httpx.get(url, headers=headers, timeout=Config.TXT2IMG_TIMEOUT)
|
62 |
return Image.open(io.BytesIO(image.content))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
63 |
else:
|
64 |
return f"Error: {response.status_code} {response.text}"
|
65 |
except Exception as e:
|
|
|
1 |
import base64
|
2 |
import io
|
3 |
+
import time
|
4 |
|
5 |
import httpx
|
6 |
import streamlit as st
|
|
|
21 |
return st.write_stream(stream)
|
22 |
except APIError as e:
|
23 |
# OpenAI uses this message for streaming errors and attaches response.error to error.body
|
24 |
+
# https://github.com/openai/openai-python/blob/v1.0.0/src/openai/_streaming.py#L59
|
25 |
return e.body if e.message == "An error occurred during streaming" else e.message
|
26 |
except Exception as e:
|
27 |
return str(e)
|
|
|
33 |
headers["Authorization"] = f"Bearer {api_key}"
|
34 |
headers["X-Wait-For-Model"] = "true"
|
35 |
headers["X-Use-Cache"] = "false"
|
36 |
+
|
37 |
if service == "Fal":
|
38 |
headers["Authorization"] = f"Key {api_key}"
|
39 |
|
40 |
+
if service == "BFL":
|
41 |
+
headers["x-key"] = api_key
|
42 |
+
|
43 |
json = {}
|
44 |
if service == "Hugging Face":
|
45 |
json = {
|
46 |
"inputs": inputs,
|
47 |
"parameters": {**parameters, **kwargs},
|
48 |
}
|
49 |
+
|
50 |
if service == "Fal":
|
51 |
json = {**parameters, **kwargs}
|
52 |
json["prompt"] = inputs
|
53 |
|
54 |
+
if service == "BFL":
|
55 |
+
json = {**parameters, **kwargs}
|
56 |
+
json["prompt"] = inputs
|
57 |
+
|
58 |
base_url = f"{Config.SERVICES[service]}/{model}"
|
59 |
|
60 |
try:
|
|
|
62 |
if response.status_code // 100 == 2: # 2xx
|
63 |
if service == "Hugging Face":
|
64 |
return Image.open(io.BytesIO(response.content))
|
65 |
+
|
66 |
if service == "Fal":
|
67 |
+
# Sync mode means wait for image base64 string instead of CDN link
|
68 |
if parameters.get("sync_mode", True):
|
69 |
bytes = base64.b64decode(response.json()["images"][0]["url"].split(",")[-1])
|
70 |
return Image.open(io.BytesIO(bytes))
|
|
|
72 |
url = response.json()["images"][0]["url"]
|
73 |
image = httpx.get(url, headers=headers, timeout=Config.TXT2IMG_TIMEOUT)
|
74 |
return Image.open(io.BytesIO(image.content))
|
75 |
+
|
76 |
+
# BFL is async so we need to poll for result
|
77 |
+
# https://api.bfl.ml/docs
|
78 |
+
if service == "BFL":
|
79 |
+
id = response.json()["id"]
|
80 |
+
url = f"{Config.SERVICES[service]}/get_result?id={id}"
|
81 |
+
|
82 |
+
retries = 0
|
83 |
+
while retries < Config.TXT2IMG_TIMEOUT:
|
84 |
+
response = httpx.get(url, timeout=Config.TXT2IMG_TIMEOUT)
|
85 |
+
if response.status_code // 100 != 2:
|
86 |
+
return f"Error: {response.status_code} {response.text}"
|
87 |
+
|
88 |
+
if response.json()["status"] == "Ready":
|
89 |
+
image = httpx.get(
|
90 |
+
response.json()["result"]["sample"],
|
91 |
+
headers=headers,
|
92 |
+
timeout=Config.TXT2IMG_TIMEOUT,
|
93 |
+
)
|
94 |
+
return Image.open(io.BytesIO(image.content))
|
95 |
+
|
96 |
+
retries += 1
|
97 |
+
time.sleep(1)
|
98 |
+
|
99 |
+
return "Error: API timeout"
|
100 |
+
|
101 |
else:
|
102 |
return f"Error: {response.status_code} {response.text}"
|
103 |
except Exception as e:
|
lib/config.py
CHANGED
@@ -5,11 +5,12 @@ Config = SimpleNamespace(
|
|
5 |
ICON="⚡",
|
6 |
LAYOUT="wide",
|
7 |
SERVICES={
|
|
|
|
|
8 |
"Hugging Face": "https://api-inference.huggingface.co/models",
|
9 |
"Perplexity": "https://api.perplexity.ai",
|
10 |
-
"Fal": "https://fal.run",
|
11 |
},
|
12 |
-
TXT2IMG_TIMEOUT=
|
13 |
TXT2IMG_HIDDEN_PARAMETERS=[
|
14 |
# sent to API but not shown in generation parameters accordion
|
15 |
"enable_safety_checker",
|
@@ -26,23 +27,33 @@ Config = SimpleNamespace(
|
|
26 |
],
|
27 |
TXT2IMG_NEGATIVE_PROMPT="ugly, unattractive, disfigured, deformed, mutated, malformed, blurry, grainy, noisy, oversaturated, undersaturated, overexposed, underexposed, worst quality, low details, lowres, watermark, signature, autograph, trademark, sloppy, cluttered",
|
28 |
TXT2IMG_DEFAULT_MODEL={
|
29 |
-
# index of model in below lists
|
|
|
|
|
30 |
"Hugging Face": 2,
|
31 |
-
"Fal": 2,
|
32 |
},
|
33 |
TXT2IMG_MODELS={
|
34 |
-
|
35 |
-
|
36 |
-
"
|
37 |
-
"
|
|
|
38 |
],
|
39 |
"Fal": [
|
40 |
"fal-ai/aura-flow",
|
|
|
|
|
41 |
"fal-ai/flux-pro",
|
|
|
42 |
"fal-ai/fooocus",
|
43 |
"fal-ai/kolors",
|
44 |
"fal-ai/stable-diffusion-v3-medium",
|
45 |
],
|
|
|
|
|
|
|
|
|
|
|
46 |
},
|
47 |
TXT2IMG_DEFAULT_IMAGE_SIZE="square_hd", # fal image sizes
|
48 |
TXT2IMG_IMAGE_SIZES=[
|
@@ -76,7 +87,6 @@ Config = SimpleNamespace(
|
|
76 |
"1344x704", # 21:11
|
77 |
"1408x704", # 2:1
|
78 |
],
|
79 |
-
# TODO: txt2img fooocus styles like "Fooocus V2" and "Fooocus Enhance" (use multiselect in UI)
|
80 |
TXT2TXT_DEFAULT_SYSTEM="You are a helpful assistant. Be precise and concise.",
|
81 |
TXT2TXT_DEFAULT_MODEL={
|
82 |
"Hugging Face": 4,
|
|
|
5 |
ICON="⚡",
|
6 |
LAYOUT="wide",
|
7 |
SERVICES={
|
8 |
+
"BFL": "https://api.bfl.ml/v1",
|
9 |
+
"Fal": "https://fal.run",
|
10 |
"Hugging Face": "https://api-inference.huggingface.co/models",
|
11 |
"Perplexity": "https://api.perplexity.ai",
|
|
|
12 |
},
|
13 |
+
TXT2IMG_TIMEOUT=60,
|
14 |
TXT2IMG_HIDDEN_PARAMETERS=[
|
15 |
# sent to API but not shown in generation parameters accordion
|
16 |
"enable_safety_checker",
|
|
|
27 |
],
|
28 |
TXT2IMG_NEGATIVE_PROMPT="ugly, unattractive, disfigured, deformed, mutated, malformed, blurry, grainy, noisy, oversaturated, undersaturated, overexposed, underexposed, worst quality, low details, lowres, watermark, signature, autograph, trademark, sloppy, cluttered",
|
29 |
TXT2IMG_DEFAULT_MODEL={
|
30 |
+
# The index of model in below lists
|
31 |
+
"BFL": 2,
|
32 |
+
"Fal": 0,
|
33 |
"Hugging Face": 2,
|
|
|
34 |
},
|
35 |
TXT2IMG_MODELS={
|
36 |
+
# Model IDs referenced in Text_to_Image.py
|
37 |
+
"BFL": [
|
38 |
+
"flux-dev",
|
39 |
+
"flux-pro",
|
40 |
+
"flux-pro-1.1",
|
41 |
],
|
42 |
"Fal": [
|
43 |
"fal-ai/aura-flow",
|
44 |
+
"fal-ai/flux/dev",
|
45 |
+
"fal-ai/flux/schnell",
|
46 |
"fal-ai/flux-pro",
|
47 |
+
"fal-ai/flux-pro/v1.1",
|
48 |
"fal-ai/fooocus",
|
49 |
"fal-ai/kolors",
|
50 |
"fal-ai/stable-diffusion-v3-medium",
|
51 |
],
|
52 |
+
"Hugging Face": [
|
53 |
+
"black-forest-labs/flux.1-dev",
|
54 |
+
"black-forest-labs/flux.1-schnell",
|
55 |
+
"stabilityai/stable-diffusion-xl-base-1.0",
|
56 |
+
],
|
57 |
},
|
58 |
TXT2IMG_DEFAULT_IMAGE_SIZE="square_hd", # fal image sizes
|
59 |
TXT2IMG_IMAGE_SIZES=[
|
|
|
87 |
"1344x704", # 21:11
|
88 |
"1408x704", # 2:1
|
89 |
],
|
|
|
90 |
TXT2TXT_DEFAULT_SYSTEM="You are a helpful assistant. Be precise and concise.",
|
91 |
TXT2TXT_DEFAULT_MODEL={
|
92 |
"Hugging Face": 4,
|
lib/presets.py
CHANGED
@@ -1,9 +1,9 @@
|
|
1 |
from types import SimpleNamespace
|
2 |
|
3 |
-
# txt2txt
|
4 |
ServicePresets = SimpleNamespace(
|
|
|
5 |
HUGGING_FACE={
|
6 |
-
# every service has model and system messages
|
7 |
"frequency_penalty": 0.0,
|
8 |
"frequency_penalty_min": -2.0,
|
9 |
"frequency_penalty_max": 2.0,
|
@@ -17,7 +17,7 @@ ServicePresets = SimpleNamespace(
|
|
17 |
},
|
18 |
)
|
19 |
|
20 |
-
# txt2img
|
21 |
ModelPresets = SimpleNamespace(
|
22 |
AURA_FLOW={
|
23 |
"name": "AuraFlow",
|
@@ -30,33 +30,89 @@ ModelPresets = SimpleNamespace(
|
|
30 |
"parameters": ["seed", "num_inference_steps", "guidance_scale", "expand_prompt"],
|
31 |
"kwargs": {"num_images": 1, "sync_mode": False},
|
32 |
},
|
33 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
"name": "FLUX.1 Dev",
|
35 |
"num_inference_steps": 28,
|
36 |
"num_inference_steps_min": 10,
|
37 |
"num_inference_steps_max": 50,
|
38 |
-
"guidance_scale": 3.
|
39 |
-
"guidance_scale_min": 1.
|
40 |
-
"guidance_scale_max":
|
41 |
-
"parameters": ["width", "height", "
|
42 |
-
"kwargs": {"
|
43 |
},
|
44 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
45 |
"name": "FLUX.1 Pro",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
46 |
"num_inference_steps": 28,
|
47 |
"num_inference_steps_min": 10,
|
48 |
"num_inference_steps_max": 50,
|
49 |
-
"guidance_scale": 3.
|
50 |
-
"guidance_scale_min": 1.
|
51 |
-
"guidance_scale_max":
|
52 |
"parameters": ["seed", "image_size", "num_inference_steps", "guidance_scale"],
|
53 |
"kwargs": {"num_images": 1, "sync_mode": False, "safety_tolerance": 6},
|
54 |
},
|
55 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
56 |
"name": "FLUX.1 Schnell",
|
57 |
"num_inference_steps": 4,
|
58 |
"num_inference_steps_min": 1,
|
59 |
-
"num_inference_steps_max":
|
60 |
"parameters": ["width", "height", "num_inference_steps"],
|
61 |
"kwargs": {"guidance_scale": 0.0, "max_sequence_length": 256},
|
62 |
},
|
|
|
1 |
from types import SimpleNamespace
|
2 |
|
3 |
+
# txt2txt
|
4 |
ServicePresets = SimpleNamespace(
|
5 |
+
# Every service has model and system messages
|
6 |
HUGGING_FACE={
|
|
|
7 |
"frequency_penalty": 0.0,
|
8 |
"frequency_penalty_min": -2.0,
|
9 |
"frequency_penalty_max": 2.0,
|
|
|
17 |
},
|
18 |
)
|
19 |
|
20 |
+
# txt2img
|
21 |
ModelPresets = SimpleNamespace(
|
22 |
AURA_FLOW={
|
23 |
"name": "AuraFlow",
|
|
|
30 |
"parameters": ["seed", "num_inference_steps", "guidance_scale", "expand_prompt"],
|
31 |
"kwargs": {"num_images": 1, "sync_mode": False},
|
32 |
},
|
33 |
+
FLUX_1_1_PRO_BFL={
|
34 |
+
"name": "FLUX1.1 Pro",
|
35 |
+
"parameters": ["seed", "width", "height", "prompt_upsampling"],
|
36 |
+
"kwargs": {"safety_tolerance": 6},
|
37 |
+
},
|
38 |
+
FLUX_PRO_BFL={
|
39 |
+
"name": "FLUX.1 Pro",
|
40 |
+
"guidance_scale": 2.5,
|
41 |
+
"guidance_scale_min": 1.5,
|
42 |
+
"guidance_scale_max": 5.0,
|
43 |
+
"num_inference_steps": 40,
|
44 |
+
"num_inference_steps_min": 10,
|
45 |
+
"num_inference_steps_max": 50,
|
46 |
+
"parameters": ["seed", "width", "height", "steps", "guidance", "prompt_upsampling"],
|
47 |
+
"kwargs": {"safety_tolerance": 6, "interval": 1},
|
48 |
+
},
|
49 |
+
FLUX_DEV_BFL={
|
50 |
"name": "FLUX.1 Dev",
|
51 |
"num_inference_steps": 28,
|
52 |
"num_inference_steps_min": 10,
|
53 |
"num_inference_steps_max": 50,
|
54 |
+
"guidance_scale": 3.0,
|
55 |
+
"guidance_scale_min": 1.5,
|
56 |
+
"guidance_scale_max": 5.0,
|
57 |
+
"parameters": ["seed", "width", "height", "steps", "guidance", "prompt_upsampling"],
|
58 |
+
"kwargs": {"safety_tolerance": 6},
|
59 |
},
|
60 |
+
FLUX_1_1_PRO_FAL={
|
61 |
+
"name": "FLUX1.1 Pro",
|
62 |
+
"parameters": ["seed", "image_size"],
|
63 |
+
"kwargs": {
|
64 |
+
"num_images": 1,
|
65 |
+
"sync_mode": False,
|
66 |
+
"safety_tolerance": 6,
|
67 |
+
"enable_safety_checker": False,
|
68 |
+
},
|
69 |
+
},
|
70 |
+
FLUX_PRO_FAL={
|
71 |
"name": "FLUX.1 Pro",
|
72 |
+
"guidance_scale": 2.5,
|
73 |
+
"guidance_scale_min": 1.5,
|
74 |
+
"guidance_scale_max": 5.0,
|
75 |
+
"num_inference_steps": 40,
|
76 |
+
"num_inference_steps_min": 10,
|
77 |
+
"num_inference_steps_max": 50,
|
78 |
+
"parameters": ["seed", "image_size", "num_inference_steps", "guidance_scale"],
|
79 |
+
"kwargs": {"num_images": 1, "sync_mode": False, "safety_tolerance": 6},
|
80 |
+
},
|
81 |
+
FLUX_DEV_FAL={
|
82 |
+
"name": "FLUX.1 Dev",
|
83 |
"num_inference_steps": 28,
|
84 |
"num_inference_steps_min": 10,
|
85 |
"num_inference_steps_max": 50,
|
86 |
+
"guidance_scale": 3.0,
|
87 |
+
"guidance_scale_min": 1.5,
|
88 |
+
"guidance_scale_max": 5.0,
|
89 |
"parameters": ["seed", "image_size", "num_inference_steps", "guidance_scale"],
|
90 |
"kwargs": {"num_images": 1, "sync_mode": False, "safety_tolerance": 6},
|
91 |
},
|
92 |
+
FLUX_SCHNELL_FAL={
|
93 |
+
"name": "FLUX.1 Schnell",
|
94 |
+
"num_inference_steps": 4,
|
95 |
+
"num_inference_steps_min": 1,
|
96 |
+
"num_inference_steps_max": 12,
|
97 |
+
"parameters": ["seed", "image_size", "num_inference_steps"],
|
98 |
+
"kwargs": {"num_images": 1, "sync_mode": False, "enable_safety_checker": False},
|
99 |
+
},
|
100 |
+
FLUX_DEV_HF={
|
101 |
+
"name": "FLUX.1 Dev",
|
102 |
+
"num_inference_steps": 28,
|
103 |
+
"num_inference_steps_min": 10,
|
104 |
+
"num_inference_steps_max": 50,
|
105 |
+
"guidance_scale": 3.0,
|
106 |
+
"guidance_scale_min": 1.5,
|
107 |
+
"guidance_scale_max": 5.0,
|
108 |
+
"parameters": ["width", "height", "guidance_scale", "num_inference_steps"],
|
109 |
+
"kwargs": {"max_sequence_length": 512},
|
110 |
+
},
|
111 |
+
FLUX_SCHNELL_HF={
|
112 |
"name": "FLUX.1 Schnell",
|
113 |
"num_inference_steps": 4,
|
114 |
"num_inference_steps_min": 1,
|
115 |
+
"num_inference_steps_max": 12,
|
116 |
"parameters": ["width", "height", "num_inference_steps"],
|
117 |
"kwargs": {"guidance_scale": 0.0, "max_sequence_length": 256},
|
118 |
},
|
pages/2_🎨_Text_to_Image.py
CHANGED
@@ -6,34 +6,45 @@ import streamlit as st
|
|
6 |
from lib import Config, ModelPresets, txt2img_generate
|
7 |
|
8 |
SERVICE_SESSION = {
|
|
|
9 |
"Fal": "api_key_fal",
|
10 |
"Hugging Face": "api_key_hugging_face",
|
11 |
}
|
12 |
|
13 |
SESSION_TOKEN = {
|
|
|
14 |
"api_key_fal": os.environ.get("FAL_KEY") or None,
|
15 |
"api_key_hugging_face": os.environ.get("HF_TOKEN") or None,
|
16 |
}
|
17 |
|
|
|
18 |
PRESET_MODEL = {
|
19 |
-
"black-forest-labs/flux.1-dev": ModelPresets.
|
20 |
-
"black-forest-labs/flux.1-schnell": ModelPresets.
|
21 |
"stabilityai/stable-diffusion-xl-base-1.0": ModelPresets.STABLE_DIFFUSION_XL,
|
22 |
"fal-ai/aura-flow": ModelPresets.AURA_FLOW,
|
23 |
-
"fal-ai/flux
|
|
|
|
|
|
|
24 |
"fal-ai/fooocus": ModelPresets.FOOOCUS,
|
25 |
"fal-ai/kolors": ModelPresets.KOLORS,
|
26 |
"fal-ai/stable-diffusion-v3-medium": ModelPresets.STABLE_DIFFUSION_3,
|
|
|
|
|
|
|
27 |
}
|
28 |
|
29 |
-
# config
|
30 |
st.set_page_config(
|
31 |
page_title=f"{Config.TITLE} | Text to Image",
|
32 |
page_icon=Config.ICON,
|
33 |
layout=Config.LAYOUT,
|
34 |
)
|
35 |
|
36 |
-
#
|
|
|
|
|
|
|
37 |
if "api_key_fal" not in st.session_state:
|
38 |
st.session_state.api_key_fal = ""
|
39 |
|
@@ -49,7 +60,6 @@ if "txt2img_messages" not in st.session_state:
|
|
49 |
if "txt2img_seed" not in st.session_state:
|
50 |
st.session_state.txt2img_seed = 0
|
51 |
|
52 |
-
# sidebar
|
53 |
st.logo("logo.svg")
|
54 |
st.sidebar.header("Settings")
|
55 |
service = st.sidebar.selectbox(
|
@@ -59,7 +69,7 @@ service = st.sidebar.selectbox(
|
|
59 |
disabled=st.session_state.running,
|
60 |
)
|
61 |
|
62 |
-
#
|
63 |
for display_name, session_key in SERVICE_SESSION.items():
|
64 |
if service == display_name:
|
65 |
st.session_state[session_key] = st.sidebar.text_input(
|
@@ -75,7 +85,6 @@ model = st.sidebar.selectbox(
|
|
75 |
options=Config.TXT2IMG_MODELS[service],
|
76 |
index=Config.TXT2IMG_DEFAULT_MODEL[service],
|
77 |
disabled=st.session_state.running,
|
78 |
-
format_func=lambda x: x.split("/")[1],
|
79 |
)
|
80 |
|
81 |
# heading
|
@@ -84,7 +93,7 @@ st.html("""
|
|
84 |
<p>Generate an image from a text prompt.</p>
|
85 |
""")
|
86 |
|
87 |
-
#
|
88 |
parameters = {}
|
89 |
preset = PRESET_MODEL[model]
|
90 |
for param in preset["parameters"]:
|
@@ -134,7 +143,7 @@ for param in preset["parameters"]:
|
|
134 |
value=Config.TXT2IMG_DEFAULT_ASPECT_RATIO,
|
135 |
disabled=st.session_state.running,
|
136 |
)
|
137 |
-
if param
|
138 |
parameters[param] = st.sidebar.slider(
|
139 |
"Guidance Scale",
|
140 |
preset["guidance_scale_min"],
|
@@ -143,7 +152,7 @@ for param in preset["parameters"]:
|
|
143 |
0.1,
|
144 |
disabled=st.session_state.running,
|
145 |
)
|
146 |
-
if param
|
147 |
parameters[param] = st.sidebar.slider(
|
148 |
"Inference Steps",
|
149 |
preset["num_inference_steps_min"],
|
@@ -152,20 +161,20 @@ for param in preset["parameters"]:
|
|
152 |
1,
|
153 |
disabled=st.session_state.running,
|
154 |
)
|
155 |
-
if param
|
156 |
parameters[param] = st.sidebar.checkbox(
|
157 |
-
"
|
158 |
value=False,
|
159 |
disabled=st.session_state.running,
|
160 |
)
|
161 |
-
if param == "
|
162 |
parameters[param] = st.sidebar.checkbox(
|
163 |
-
"Prompt
|
164 |
value=False,
|
165 |
disabled=st.session_state.running,
|
166 |
)
|
167 |
|
168 |
-
#
|
169 |
for message in st.session_state.txt2img_messages:
|
170 |
role = message["role"]
|
171 |
with st.chat_message(role):
|
@@ -202,7 +211,7 @@ for message in st.session_state.txt2img_messages:
|
|
202 |
""")
|
203 |
st.write(message["content"]) # success will be image, error will be text
|
204 |
|
205 |
-
#
|
206 |
if st.session_state.txt2img_messages:
|
207 |
button_container = st.empty()
|
208 |
with button_container.container():
|
@@ -235,7 +244,8 @@ if st.session_state.txt2img_messages:
|
|
235 |
else:
|
236 |
button_container = None
|
237 |
|
238 |
-
#
|
|
|
239 |
if prompt := st.chat_input(
|
240 |
"What do you want to see?",
|
241 |
on_submit=lambda: setattr(st.session_state, "running", True),
|
|
|
6 |
from lib import Config, ModelPresets, txt2img_generate
|
7 |
|
8 |
SERVICE_SESSION = {
|
9 |
+
"BFL": "api_key_bfl",
|
10 |
"Fal": "api_key_fal",
|
11 |
"Hugging Face": "api_key_hugging_face",
|
12 |
}
|
13 |
|
14 |
SESSION_TOKEN = {
|
15 |
+
"api_key_bfl": os.environ.get("BFL_API_KEY") or None,
|
16 |
"api_key_fal": os.environ.get("FAL_KEY") or None,
|
17 |
"api_key_hugging_face": os.environ.get("HF_TOKEN") or None,
|
18 |
}
|
19 |
|
20 |
+
# Model IDs in lib/config.py
|
21 |
PRESET_MODEL = {
|
22 |
+
"black-forest-labs/flux.1-dev": ModelPresets.FLUX_DEV_HF,
|
23 |
+
"black-forest-labs/flux.1-schnell": ModelPresets.FLUX_SCHNELL_HF,
|
24 |
"stabilityai/stable-diffusion-xl-base-1.0": ModelPresets.STABLE_DIFFUSION_XL,
|
25 |
"fal-ai/aura-flow": ModelPresets.AURA_FLOW,
|
26 |
+
"fal-ai/flux/dev": ModelPresets.FLUX_DEV_FAL,
|
27 |
+
"fal-ai/flux/schnell": ModelPresets.FLUX_SCHNELL_FAL,
|
28 |
+
"fal-ai/flux-pro": ModelPresets.FLUX_PRO_FAL,
|
29 |
+
"fal-ai/flux-pro/v1.1": ModelPresets.FLUX_1_1_PRO_FAL,
|
30 |
"fal-ai/fooocus": ModelPresets.FOOOCUS,
|
31 |
"fal-ai/kolors": ModelPresets.KOLORS,
|
32 |
"fal-ai/stable-diffusion-v3-medium": ModelPresets.STABLE_DIFFUSION_3,
|
33 |
+
"flux-pro-1.1": ModelPresets.FLUX_1_1_PRO_BFL,
|
34 |
+
"flux-pro": ModelPresets.FLUX_PRO_BFL,
|
35 |
+
"flux-dev": ModelPresets.FLUX_DEV_BFL,
|
36 |
}
|
37 |
|
|
|
38 |
st.set_page_config(
|
39 |
page_title=f"{Config.TITLE} | Text to Image",
|
40 |
page_icon=Config.ICON,
|
41 |
layout=Config.LAYOUT,
|
42 |
)
|
43 |
|
44 |
+
# Initialize Streamlit session state
|
45 |
+
if "api_key_bfl" not in st.session_state:
|
46 |
+
st.session_state.api_key_bfl = ""
|
47 |
+
|
48 |
if "api_key_fal" not in st.session_state:
|
49 |
st.session_state.api_key_fal = ""
|
50 |
|
|
|
60 |
if "txt2img_seed" not in st.session_state:
|
61 |
st.session_state.txt2img_seed = 0
|
62 |
|
|
|
63 |
st.logo("logo.svg")
|
64 |
st.sidebar.header("Settings")
|
65 |
service = st.sidebar.selectbox(
|
|
|
69 |
disabled=st.session_state.running,
|
70 |
)
|
71 |
|
72 |
+
# Disable API key input and hide value if set by environment variable; handle empty string value later.
|
73 |
for display_name, session_key in SERVICE_SESSION.items():
|
74 |
if service == display_name:
|
75 |
st.session_state[session_key] = st.sidebar.text_input(
|
|
|
85 |
options=Config.TXT2IMG_MODELS[service],
|
86 |
index=Config.TXT2IMG_DEFAULT_MODEL[service],
|
87 |
disabled=st.session_state.running,
|
|
|
88 |
)
|
89 |
|
90 |
# heading
|
|
|
93 |
<p>Generate an image from a text prompt.</p>
|
94 |
""")
|
95 |
|
96 |
+
# Build parameters from preset by rendering the appropriate input widgets
|
97 |
parameters = {}
|
98 |
preset = PRESET_MODEL[model]
|
99 |
for param in preset["parameters"]:
|
|
|
143 |
value=Config.TXT2IMG_DEFAULT_ASPECT_RATIO,
|
144 |
disabled=st.session_state.running,
|
145 |
)
|
146 |
+
if param in ["guidance_scale", "guidance"]:
|
147 |
parameters[param] = st.sidebar.slider(
|
148 |
"Guidance Scale",
|
149 |
preset["guidance_scale_min"],
|
|
|
152 |
0.1,
|
153 |
disabled=st.session_state.running,
|
154 |
)
|
155 |
+
if param in ["num_inference_steps", "steps"]:
|
156 |
parameters[param] = st.sidebar.slider(
|
157 |
"Inference Steps",
|
158 |
preset["num_inference_steps_min"],
|
|
|
161 |
1,
|
162 |
disabled=st.session_state.running,
|
163 |
)
|
164 |
+
if param in ["expand_prompt", "prompt_expansion"]:
|
165 |
parameters[param] = st.sidebar.checkbox(
|
166 |
+
"Prompt Expansion",
|
167 |
value=False,
|
168 |
disabled=st.session_state.running,
|
169 |
)
|
170 |
+
if param == "prompt_upsampling":
|
171 |
parameters[param] = st.sidebar.checkbox(
|
172 |
+
"Prompt Upsampling",
|
173 |
value=False,
|
174 |
disabled=st.session_state.running,
|
175 |
)
|
176 |
|
177 |
+
# Wrap the prompt in an accordion to display additional parameters
|
178 |
for message in st.session_state.txt2img_messages:
|
179 |
role = message["role"]
|
180 |
with st.chat_message(role):
|
|
|
211 |
""")
|
212 |
st.write(message["content"]) # success will be image, error will be text
|
213 |
|
214 |
+
# Buttons for deleting last generation or clearing all generations
|
215 |
if st.session_state.txt2img_messages:
|
216 |
button_container = st.empty()
|
217 |
with button_container.container():
|
|
|
244 |
else:
|
245 |
button_container = None
|
246 |
|
247 |
+
# Set running state to True and show spinner while loading.
|
248 |
+
# Update state and refresh on response; errors will be displayed as chat messages.
|
249 |
if prompt := st.chat_input(
|
250 |
"What do you want to see?",
|
251 |
on_submit=lambda: setattr(st.session_state, "running", True),
|