adamelliotfields
commited on
Add initial Fal txt2img models
Browse files- lib/api.py +27 -16
- lib/config.py +69 -13
- lib/presets.py +102 -7
- pages/2_🎨_Text_to_Image.py +84 -45
lib/api.py
CHANGED
@@ -1,3 +1,4 @@
|
|
|
|
1 |
import io
|
2 |
|
3 |
import requests
|
@@ -24,25 +25,35 @@ def txt2txt_generate(api_key, service, model, parameters, **kwargs):
|
|
24 |
|
25 |
|
26 |
def txt2img_generate(api_key, service, model, inputs, parameters, **kwargs):
|
27 |
-
headers = {
|
28 |
-
|
29 |
-
"
|
30 |
-
"X-
|
31 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
base_url = f"{Config.SERVICES[service]}/{model}"
|
33 |
|
34 |
try:
|
35 |
-
response = requests.post(
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
"
|
41 |
-
|
42 |
-
)
|
43 |
-
if response.status_code == 200:
|
44 |
-
return Image.open(io.BytesIO(response.content))
|
45 |
else:
|
46 |
-
return f"Error: {response.status_code}
|
47 |
except Exception as e:
|
48 |
return str(e)
|
|
|
1 |
+
import base64
|
2 |
import io
|
3 |
|
4 |
import requests
|
|
|
25 |
|
26 |
|
27 |
def txt2img_generate(api_key, service, model, inputs, parameters, **kwargs):
|
28 |
+
headers = {}
|
29 |
+
if service == "Huggingface":
|
30 |
+
headers["Authorization"] = f"Bearer {api_key}"
|
31 |
+
headers["X-Wait-For-Model"] = "true"
|
32 |
+
headers["X-Use-Cache"] = "false"
|
33 |
+
if service == "Fal":
|
34 |
+
headers["Authorization"] = f"Key {api_key}"
|
35 |
+
|
36 |
+
json = {}
|
37 |
+
if service == "Huggingface":
|
38 |
+
json = {
|
39 |
+
"inputs": inputs,
|
40 |
+
"parameters": {**parameters, **kwargs},
|
41 |
+
}
|
42 |
+
if service == "Fal":
|
43 |
+
json = {**parameters, **kwargs}
|
44 |
+
json["prompt"] = inputs
|
45 |
+
|
46 |
base_url = f"{Config.SERVICES[service]}/{model}"
|
47 |
|
48 |
try:
|
49 |
+
response = requests.post(base_url, headers=headers, json=json)
|
50 |
+
if response.status_code // 100 == 2: # 2xx
|
51 |
+
if service == "Huggingface":
|
52 |
+
return Image.open(io.BytesIO(response.content))
|
53 |
+
if service == "Fal":
|
54 |
+
bytes = base64.b64decode(response.json()["images"][0]["url"].split(",")[-1])
|
55 |
+
return Image.open(io.BytesIO(bytes))
|
|
|
|
|
|
|
56 |
else:
|
57 |
+
return f"Error: {response.status_code} {response.text}"
|
58 |
except Exception as e:
|
59 |
return str(e)
|
lib/config.py
CHANGED
@@ -7,22 +7,78 @@ Config = SimpleNamespace(
|
|
7 |
SERVICES={
|
8 |
"Huggingface": "https://api-inference.huggingface.co/models",
|
9 |
"Perplexity": "https://api.perplexity.ai",
|
|
|
10 |
},
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
"
|
15 |
-
"
|
16 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
],
|
18 |
-
|
19 |
-
|
20 |
-
"
|
21 |
-
"
|
22 |
-
|
23 |
-
|
24 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
TXT2TXT_DEFAULT_SYSTEM="You are a helpful assistant. Be precise and concise.",
|
27 |
TXT2TXT_DEFAULT_MODEL={
|
28 |
"Huggingface": 4,
|
|
|
7 |
SERVICES={
|
8 |
"Huggingface": "https://api-inference.huggingface.co/models",
|
9 |
"Perplexity": "https://api.perplexity.ai",
|
10 |
+
"Fal": "https://fal.run",
|
11 |
},
|
12 |
+
TXT2IMG_HIDDEN_PARAMETERS=[
|
13 |
+
"enable_safety_checker",
|
14 |
+
"max_sequence_length",
|
15 |
+
"num_images",
|
16 |
+
"output_format",
|
17 |
+
"performance",
|
18 |
+
"safety_tolerance",
|
19 |
+
"scheduler",
|
20 |
+
"sharpness",
|
21 |
+
"style",
|
22 |
+
"styles",
|
23 |
+
"sync_mode",
|
24 |
],
|
25 |
+
TXT2IMG_NEGATIVE_PROMPT="ugly, unattractive, malformed, mutated, disgusting, blurry, grainy, noisy, oversaturated, undersaturated, overexposed, underexposed, worst quality, low details, lowres, watermark, signature, autograph, trademark, sloppy, cluttered",
|
26 |
+
TXT2IMG_DEFAULT_MODEL={
|
27 |
+
"Huggingface": 2,
|
28 |
+
"Fal": 1,
|
29 |
+
},
|
30 |
+
TXT2IMG_MODELS={
|
31 |
+
"Huggingface": [
|
32 |
+
"black-forest-labs/flux.1-dev",
|
33 |
+
"black-forest-labs/flux.1-schnell",
|
34 |
+
"stabilityai/stable-diffusion-xl-base-1.0",
|
35 |
+
],
|
36 |
+
"Fal": [
|
37 |
+
# "fal-ai/aura-flow",
|
38 |
+
# "fal-ai/flux-pro",
|
39 |
+
"fal-ai/fooocus",
|
40 |
+
"fal-ai/kolors",
|
41 |
+
"fal-ai/pixart-sigma",
|
42 |
+
"fal-ai/stable-diffusion-v3-medium",
|
43 |
+
],
|
44 |
},
|
45 |
+
TXT2IMG_DEFAULT_IMAGE_SIZE="square_hd", # fal image sizes
|
46 |
+
TXT2IMG_IMAGE_SIZES=[
|
47 |
+
"landscape_16_9",
|
48 |
+
"landscape_4_3",
|
49 |
+
"square_hd",
|
50 |
+
"portrait_4_3",
|
51 |
+
"portrait_16_9",
|
52 |
+
],
|
53 |
+
TXT2IMG_DEFAULT_ASPECT_RATIO="1024x1024", # fooocus aspect ratios
|
54 |
+
TXT2IMG_ASPECT_RATIOS=[
|
55 |
+
"704x1408",
|
56 |
+
"704x1344",
|
57 |
+
"768x1344",
|
58 |
+
"768x1280",
|
59 |
+
"832x1216",
|
60 |
+
"832x1152",
|
61 |
+
"896x1152",
|
62 |
+
"896x1088",
|
63 |
+
"960x1088",
|
64 |
+
"960x1024",
|
65 |
+
"1024x1024",
|
66 |
+
"1024x960",
|
67 |
+
"1088x960",
|
68 |
+
"1088x896",
|
69 |
+
"1152x896",
|
70 |
+
"1152x832",
|
71 |
+
"1216x832",
|
72 |
+
"1280x768",
|
73 |
+
"1344x768",
|
74 |
+
"1344x704",
|
75 |
+
"1408x704",
|
76 |
+
"1472x704",
|
77 |
+
"1536x640",
|
78 |
+
"1600x640",
|
79 |
+
"1664x576",
|
80 |
+
"1728x576",
|
81 |
+
],
|
82 |
TXT2TXT_DEFAULT_SYSTEM="You are a helpful assistant. Be precise and concise.",
|
83 |
TXT2TXT_DEFAULT_MODEL={
|
84 |
"Huggingface": 4,
|
lib/presets.py
CHANGED
@@ -19,18 +19,40 @@ ServicePresets = SimpleNamespace(
|
|
19 |
|
20 |
# txt2img models
|
21 |
ModelPresets = SimpleNamespace(
|
22 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
"name": "FLUX.1 Dev",
|
24 |
-
"num_inference_steps":
|
25 |
"num_inference_steps_min": 10,
|
26 |
-
"num_inference_steps_max":
|
27 |
"guidance_scale": 3.5,
|
28 |
"guidance_scale_min": 1.0,
|
29 |
-
"guidance_scale_max":
|
30 |
"parameters": ["width", "height", "guidance_scale", "num_inference_steps"],
|
31 |
"kwargs": {"max_sequence_length": 512},
|
32 |
},
|
33 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
"name": "FLUX.1 Schnell",
|
35 |
"num_inference_steps": 4,
|
36 |
"num_inference_steps_min": 1,
|
@@ -38,14 +60,87 @@ ModelPresets = SimpleNamespace(
|
|
38 |
"parameters": ["width", "height", "num_inference_steps"],
|
39 |
"kwargs": {"guidance_scale": 0.0, "max_sequence_length": 256},
|
40 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
41 |
STABLE_DIFFUSION_XL={
|
42 |
"name": "SDXL",
|
43 |
"guidance_scale": 7.0,
|
44 |
"guidance_scale_min": 1.0,
|
45 |
-
"guidance_scale_max":
|
46 |
"num_inference_steps": 40,
|
47 |
"num_inference_steps_min": 10,
|
48 |
"num_inference_steps_max": 50,
|
49 |
-
"parameters": ["
|
50 |
},
|
51 |
)
|
|
|
19 |
|
20 |
# txt2img models
|
21 |
ModelPresets = SimpleNamespace(
|
22 |
+
AURA_FLOW={
|
23 |
+
"name": "AuraFlow",
|
24 |
+
"guidance_scale": 3.5,
|
25 |
+
"guidance_scale_min": 1.0,
|
26 |
+
"guidance_scale_max": 10.0,
|
27 |
+
"num_inference_steps": 50,
|
28 |
+
"num_inference_steps_min": 10,
|
29 |
+
"num_inference_steps_max": 50,
|
30 |
+
"parameters": ["seed", "guidance_scale", "num_inference_steps", "expand_prompt"],
|
31 |
+
"kwargs": {"num_images": 1},
|
32 |
+
},
|
33 |
+
FLUX_DEV={
|
34 |
"name": "FLUX.1 Dev",
|
35 |
+
"num_inference_steps": 28,
|
36 |
"num_inference_steps_min": 10,
|
37 |
+
"num_inference_steps_max": 50,
|
38 |
"guidance_scale": 3.5,
|
39 |
"guidance_scale_min": 1.0,
|
40 |
+
"guidance_scale_max": 10.0,
|
41 |
"parameters": ["width", "height", "guidance_scale", "num_inference_steps"],
|
42 |
"kwargs": {"max_sequence_length": 512},
|
43 |
},
|
44 |
+
FLUX_PRO={
|
45 |
+
"name": "FLUX.1 Pro",
|
46 |
+
"num_inference_steps": 28,
|
47 |
+
"num_inference_steps_min": 10,
|
48 |
+
"num_inference_steps_max": 50,
|
49 |
+
"guidance_scale": 3.5,
|
50 |
+
"guidance_scale_min": 1.0,
|
51 |
+
"guidance_scale_max": 10.0,
|
52 |
+
"parameters": ["seed", "image_size", "num_inference_steps", "guidance_scale"],
|
53 |
+
"kwargs": {"num_images": 1, "sync_mode": True, "safety_tolerance": 6},
|
54 |
+
},
|
55 |
+
FLUX_SCHNELL={
|
56 |
"name": "FLUX.1 Schnell",
|
57 |
"num_inference_steps": 4,
|
58 |
"num_inference_steps_min": 1,
|
|
|
60 |
"parameters": ["width", "height", "num_inference_steps"],
|
61 |
"kwargs": {"guidance_scale": 0.0, "max_sequence_length": 256},
|
62 |
},
|
63 |
+
FOOOCUS={
|
64 |
+
"name": "Fooocus",
|
65 |
+
"guidance_scale": 4.0,
|
66 |
+
"guidance_scale_min": 1.0,
|
67 |
+
"guidance_scale_max": 10.0,
|
68 |
+
"parameters": ["seed", "negative_prompt", "aspect_ratio", "guidance_scale"],
|
69 |
+
"kwargs": {
|
70 |
+
"num_images": 1,
|
71 |
+
"sync_mode": True,
|
72 |
+
"enable_safety_checker": False,
|
73 |
+
"output_format": "png",
|
74 |
+
"sharpness": 2,
|
75 |
+
"styles": ["Fooocus Enhance", "Fooocus V2", "Fooocus Sharp"],
|
76 |
+
"performance": "Quality",
|
77 |
+
},
|
78 |
+
},
|
79 |
+
KOLORS={
|
80 |
+
"name": "Kolors",
|
81 |
+
"guidance_scale": 5.0,
|
82 |
+
"guidance_scale_min": 1.0,
|
83 |
+
"guidance_scale_max": 10.0,
|
84 |
+
"num_inference_steps": 50,
|
85 |
+
"num_inference_steps_min": 10,
|
86 |
+
"num_inference_steps_max": 50,
|
87 |
+
"parameters": [
|
88 |
+
"seed",
|
89 |
+
"negative_prompt",
|
90 |
+
"image_size",
|
91 |
+
"guidance_scale",
|
92 |
+
"num_inference_steps",
|
93 |
+
],
|
94 |
+
"kwargs": {
|
95 |
+
"num_images": 1,
|
96 |
+
"sync_mode": True,
|
97 |
+
"enable_safety_checker": False,
|
98 |
+
"scheduler": "EulerDiscreteScheduler",
|
99 |
+
},
|
100 |
+
},
|
101 |
+
PIXART_SIGMA={
|
102 |
+
"name": "PixArt-Σ",
|
103 |
+
"guidance_scale": 4.5,
|
104 |
+
"guidance_scale_min": 1.0,
|
105 |
+
"guidance_scale_max": 10.0,
|
106 |
+
"num_inference_steps": 35,
|
107 |
+
"num_inference_steps_min": 10,
|
108 |
+
"num_inference_steps_max": 50,
|
109 |
+
"parameters": ["seed", "negative_prompt", "image_size", "guidance_scale", "num_inference_steps"],
|
110 |
+
"kwargs": {
|
111 |
+
"num_images": 1,
|
112 |
+
"sync_mode": True,
|
113 |
+
"enable_safety_checker": False,
|
114 |
+
"style": "(No style)",
|
115 |
+
"scheduler": "SA-SOLVER",
|
116 |
+
},
|
117 |
+
},
|
118 |
+
STABLE_DIFFUSION_3={
|
119 |
+
"name": "SD3",
|
120 |
+
"guidance_scale": 5.0,
|
121 |
+
"guidance_scale_min": 1.0,
|
122 |
+
"guidance_scale_max": 10.0,
|
123 |
+
"num_inference_steps": 28,
|
124 |
+
"num_inference_steps_min": 10,
|
125 |
+
"num_inference_steps_max": 50,
|
126 |
+
"parameters": [
|
127 |
+
"seed",
|
128 |
+
"negative_prompt",
|
129 |
+
"image_size",
|
130 |
+
"guidance_scale",
|
131 |
+
"num_inference_steps",
|
132 |
+
"prompt_expansion",
|
133 |
+
],
|
134 |
+
"kwargs": {"num_images": 1, "sync_mode": True, "enable_safety_checker": False},
|
135 |
+
},
|
136 |
STABLE_DIFFUSION_XL={
|
137 |
"name": "SDXL",
|
138 |
"guidance_scale": 7.0,
|
139 |
"guidance_scale_min": 1.0,
|
140 |
+
"guidance_scale_max": 10.0,
|
141 |
"num_inference_steps": 40,
|
142 |
"num_inference_steps_min": 10,
|
143 |
"num_inference_steps_max": 50,
|
144 |
+
"parameters": ["seed", "negative_prompt", "width", "height", "guidance_scale", "num_inference_steps"],
|
145 |
},
|
146 |
)
|
pages/2_🎨_Text_to_Image.py
CHANGED
@@ -5,16 +5,16 @@ import streamlit as st
|
|
5 |
|
6 |
from lib import Config, ModelPresets, txt2img_generate
|
7 |
|
8 |
-
HF_TOKEN = None
|
9 |
-
FAL_KEY = None
|
10 |
-
# HF_TOKEN = os.environ.get("HF_TOKEN") or None
|
11 |
-
# FAL_KEY = os.environ.get("FAL_KEY") or None
|
12 |
-
API_URL = "https://api-inference.huggingface.co/models"
|
13 |
-
HEADERS = {"Authorization": f"Bearer {HF_TOKEN}", "X-Wait-For-Model": "true", "X-Use-Cache": "false"}
|
14 |
PRESET_MODEL = {
|
15 |
-
"black-forest-labs/flux.1-dev": ModelPresets.
|
16 |
-
"black-forest-labs/flux.1-schnell": ModelPresets.
|
17 |
"stabilityai/stable-diffusion-xl-base-1.0": ModelPresets.STABLE_DIFFUSION_XL,
|
|
|
|
|
|
|
|
|
18 |
}
|
19 |
|
20 |
# config
|
@@ -34,8 +34,8 @@ if "api_key_huggingface" not in st.session_state:
|
|
34 |
if "txt2img_messages" not in st.session_state:
|
35 |
st.session_state.txt2img_messages = []
|
36 |
|
37 |
-
if "
|
38 |
-
st.session_state.
|
39 |
|
40 |
if "txt2img_seed" not in st.session_state:
|
41 |
st.session_state.txt2img_seed = 0
|
@@ -45,9 +45,9 @@ st.logo("logo.svg")
|
|
45 |
st.sidebar.header("Settings")
|
46 |
service = st.sidebar.selectbox(
|
47 |
"Service",
|
48 |
-
options=["Huggingface"],
|
49 |
-
index=
|
50 |
-
disabled=st.session_state.
|
51 |
)
|
52 |
|
53 |
if service == "Huggingface" and HF_TOKEN is None:
|
@@ -56,7 +56,7 @@ if service == "Huggingface" and HF_TOKEN is None:
|
|
56 |
type="password",
|
57 |
help="Cleared on page refresh",
|
58 |
value=st.session_state.api_key_huggingface,
|
59 |
-
disabled=st.session_state.
|
60 |
)
|
61 |
else:
|
62 |
st.session_state.api_key_huggingface = st.session_state.api_key_huggingface
|
@@ -67,7 +67,7 @@ if service == "Fal" and FAL_KEY is None:
|
|
67 |
type="password",
|
68 |
help="Cleared on page refresh",
|
69 |
value=st.session_state.api_key_fal,
|
70 |
-
disabled=st.session_state.
|
71 |
)
|
72 |
else:
|
73 |
st.session_state.api_key_fal = st.session_state.api_key_fal
|
@@ -80,17 +80,11 @@ if service == "Fal" and FAL_KEY is not None:
|
|
80 |
|
81 |
model = st.sidebar.selectbox(
|
82 |
"Model",
|
83 |
-
options=Config.TXT2IMG_MODELS,
|
84 |
-
index=Config.TXT2IMG_DEFAULT_MODEL,
|
85 |
-
disabled=st.session_state.
|
86 |
format_func=lambda x: x.split("/")[1],
|
87 |
)
|
88 |
-
aspect_ratio = st.sidebar.select_slider(
|
89 |
-
"Aspect Ratio",
|
90 |
-
options=list(Config.TXT2IMG_AR.keys()),
|
91 |
-
value=Config.TXT2IMG_DEFAULT_AR,
|
92 |
-
disabled=st.session_state.txt2img_running,
|
93 |
-
)
|
94 |
|
95 |
# heading
|
96 |
st.html("""
|
@@ -102,10 +96,52 @@ st.html("""
|
|
102 |
parameters = {}
|
103 |
preset = PRESET_MODEL[model]
|
104 |
for param in preset["parameters"]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
105 |
if param == "width":
|
106 |
-
parameters[param] =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
107 |
if param == "height":
|
108 |
-
parameters[param] =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
109 |
if param == "guidance_scale":
|
110 |
parameters[param] = st.sidebar.slider(
|
111 |
"Guidance Scale",
|
@@ -113,7 +149,7 @@ for param in preset["parameters"]:
|
|
113 |
preset["guidance_scale_max"],
|
114 |
preset["guidance_scale"],
|
115 |
0.1,
|
116 |
-
disabled=st.session_state.
|
117 |
)
|
118 |
if param == "num_inference_steps":
|
119 |
parameters[param] = st.sidebar.slider(
|
@@ -122,21 +158,19 @@ for param in preset["parameters"]:
|
|
122 |
preset["num_inference_steps_max"],
|
123 |
preset["num_inference_steps"],
|
124 |
1,
|
125 |
-
disabled=st.session_state.
|
126 |
)
|
127 |
-
if param == "
|
128 |
-
parameters[param] = st.sidebar.
|
129 |
-
"
|
130 |
-
|
131 |
-
|
132 |
-
value=-1,
|
133 |
-
disabled=st.session_state.txt2img_running,
|
134 |
)
|
135 |
-
if param == "
|
136 |
-
parameters[param] = st.sidebar.
|
137 |
-
|
138 |
-
value=
|
139 |
-
disabled=st.session_state.
|
140 |
)
|
141 |
|
142 |
# wrap the prompt in an expander to display additional parameters
|
@@ -154,8 +188,13 @@ for message in st.session_state.txt2img_messages:
|
|
154 |
div[data-testid="stMarkdownContainer"] p:not(:last-of-type) { margin-bottom: 0 }
|
155 |
</style>
|
156 |
""")
|
|
|
|
|
|
|
|
|
|
|
157 |
md = f"`model`: {message['model']}\n\n"
|
158 |
-
md += "\n\n".join([f"`{k}`: {v}" for k, v in
|
159 |
st.markdown(md)
|
160 |
|
161 |
if role == "assistant":
|
@@ -192,7 +231,7 @@ if st.session_state.txt2img_messages:
|
|
192 |
col1, col2 = st.columns(2)
|
193 |
with col1:
|
194 |
if (
|
195 |
-
st.button("❌", help="Delete last generation", disabled=st.session_state.
|
196 |
and len(st.session_state.txt2img_messages) >= 2
|
197 |
):
|
198 |
st.session_state.txt2img_messages.pop()
|
@@ -200,7 +239,7 @@ if st.session_state.txt2img_messages:
|
|
200 |
st.rerun()
|
201 |
|
202 |
with col2:
|
203 |
-
if st.button("🗑️", help="Clear all generations", disabled=st.session_state.
|
204 |
st.session_state.txt2img_messages = []
|
205 |
st.session_state.txt2img_seed = 0
|
206 |
st.rerun()
|
@@ -210,7 +249,7 @@ else:
|
|
210 |
# show the prompt and spinner while loading then update state and re-render
|
211 |
if prompt := st.chat_input(
|
212 |
"What do you want to see?",
|
213 |
-
on_submit=lambda: setattr(st.session_state, "
|
214 |
):
|
215 |
if "seed" in parameters and parameters["seed"] >= 0:
|
216 |
st.session_state.txt2img_seed = parameters["seed"]
|
@@ -231,7 +270,7 @@ if prompt := st.chat_input(
|
|
231 |
parameters.update(preset["kwargs"])
|
232 |
api_key = getattr(st.session_state, f"api_key_{service.lower()}", None)
|
233 |
image = txt2img_generate(api_key, service, model, prompt, parameters)
|
234 |
-
st.session_state.
|
235 |
|
236 |
model_name = PRESET_MODEL[model]["name"]
|
237 |
st.session_state.txt2img_messages.append(
|
|
|
5 |
|
6 |
from lib import Config, ModelPresets, txt2img_generate
|
7 |
|
8 |
+
HF_TOKEN = os.environ.get("HF_TOKEN") or None
|
9 |
+
FAL_KEY = os.environ.get("FAL_KEY") or None
|
|
|
|
|
|
|
|
|
10 |
PRESET_MODEL = {
|
11 |
+
"black-forest-labs/flux.1-dev": ModelPresets.FLUX_DEV,
|
12 |
+
"black-forest-labs/flux.1-schnell": ModelPresets.FLUX_SCHNELL,
|
13 |
"stabilityai/stable-diffusion-xl-base-1.0": ModelPresets.STABLE_DIFFUSION_XL,
|
14 |
+
"fal-ai/fooocus": ModelPresets.FOOOCUS,
|
15 |
+
"fal-ai/kolors": ModelPresets.KOLORS,
|
16 |
+
"fal-ai/pixart-sigma": ModelPresets.PIXART_SIGMA,
|
17 |
+
"fal-ai/stable-diffusion-v3-medium": ModelPresets.STABLE_DIFFUSION_3,
|
18 |
}
|
19 |
|
20 |
# config
|
|
|
34 |
if "txt2img_messages" not in st.session_state:
|
35 |
st.session_state.txt2img_messages = []
|
36 |
|
37 |
+
if "running" not in st.session_state:
|
38 |
+
st.session_state.running = False
|
39 |
|
40 |
if "txt2img_seed" not in st.session_state:
|
41 |
st.session_state.txt2img_seed = 0
|
|
|
45 |
st.sidebar.header("Settings")
|
46 |
service = st.sidebar.selectbox(
|
47 |
"Service",
|
48 |
+
options=["Fal", "Huggingface"],
|
49 |
+
index=1,
|
50 |
+
disabled=st.session_state.running,
|
51 |
)
|
52 |
|
53 |
if service == "Huggingface" and HF_TOKEN is None:
|
|
|
56 |
type="password",
|
57 |
help="Cleared on page refresh",
|
58 |
value=st.session_state.api_key_huggingface,
|
59 |
+
disabled=st.session_state.running,
|
60 |
)
|
61 |
else:
|
62 |
st.session_state.api_key_huggingface = st.session_state.api_key_huggingface
|
|
|
67 |
type="password",
|
68 |
help="Cleared on page refresh",
|
69 |
value=st.session_state.api_key_fal,
|
70 |
+
disabled=st.session_state.running,
|
71 |
)
|
72 |
else:
|
73 |
st.session_state.api_key_fal = st.session_state.api_key_fal
|
|
|
80 |
|
81 |
model = st.sidebar.selectbox(
|
82 |
"Model",
|
83 |
+
options=Config.TXT2IMG_MODELS[service],
|
84 |
+
index=Config.TXT2IMG_DEFAULT_MODEL[service],
|
85 |
+
disabled=st.session_state.running,
|
86 |
format_func=lambda x: x.split("/")[1],
|
87 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
88 |
|
89 |
# heading
|
90 |
st.html("""
|
|
|
96 |
parameters = {}
|
97 |
preset = PRESET_MODEL[model]
|
98 |
for param in preset["parameters"]:
|
99 |
+
if param == "seed":
|
100 |
+
parameters[param] = st.sidebar.number_input(
|
101 |
+
"Seed",
|
102 |
+
min_value=-1,
|
103 |
+
max_value=(1 << 53) - 1,
|
104 |
+
value=-1,
|
105 |
+
disabled=st.session_state.running,
|
106 |
+
)
|
107 |
+
if param == "negative_prompt":
|
108 |
+
parameters[param] = st.sidebar.text_area(
|
109 |
+
"Negative Prompt",
|
110 |
+
value=Config.TXT2IMG_NEGATIVE_PROMPT,
|
111 |
+
disabled=st.session_state.running,
|
112 |
+
)
|
113 |
if param == "width":
|
114 |
+
parameters[param] = st.sidebar.slider(
|
115 |
+
"Width",
|
116 |
+
step=64,
|
117 |
+
value=1024,
|
118 |
+
min_value=512,
|
119 |
+
max_value=2048,
|
120 |
+
disabled=st.session_state.running,
|
121 |
+
)
|
122 |
if param == "height":
|
123 |
+
parameters[param] = st.sidebar.slider(
|
124 |
+
"Height",
|
125 |
+
step=64,
|
126 |
+
value=1024,
|
127 |
+
min_value=512,
|
128 |
+
max_value=2048,
|
129 |
+
disabled=st.session_state.running,
|
130 |
+
)
|
131 |
+
if param == "image_size":
|
132 |
+
parameters[param] = st.sidebar.select_slider(
|
133 |
+
"Image Size",
|
134 |
+
options=Config.TXT2IMG_IMAGE_SIZES,
|
135 |
+
value=Config.TXT2IMG_DEFAULT_IMAGE_SIZE,
|
136 |
+
disabled=st.session_state.running,
|
137 |
+
)
|
138 |
+
if param == "aspect_ratio":
|
139 |
+
parameters[param] = st.sidebar.select_slider(
|
140 |
+
"Aspect Ratio",
|
141 |
+
options=Config.TXT2IMG_ASPECT_RATIOS,
|
142 |
+
value=Config.TXT2IMG_DEFAULT_ASPECT_RATIO,
|
143 |
+
disabled=st.session_state.running,
|
144 |
+
)
|
145 |
if param == "guidance_scale":
|
146 |
parameters[param] = st.sidebar.slider(
|
147 |
"Guidance Scale",
|
|
|
149 |
preset["guidance_scale_max"],
|
150 |
preset["guidance_scale"],
|
151 |
0.1,
|
152 |
+
disabled=st.session_state.running,
|
153 |
)
|
154 |
if param == "num_inference_steps":
|
155 |
parameters[param] = st.sidebar.slider(
|
|
|
158 |
preset["num_inference_steps_max"],
|
159 |
preset["num_inference_steps"],
|
160 |
1,
|
161 |
+
disabled=st.session_state.running,
|
162 |
)
|
163 |
+
if param == "expand_prompt":
|
164 |
+
parameters[param] = st.sidebar.checkbox(
|
165 |
+
"Expand Prompt",
|
166 |
+
value=False,
|
167 |
+
disabled=st.session_state.running,
|
|
|
|
|
168 |
)
|
169 |
+
if param == "prompt_expansion":
|
170 |
+
parameters[param] = st.sidebar.checkbox(
|
171 |
+
"Prompt Expansion",
|
172 |
+
value=False,
|
173 |
+
disabled=st.session_state.running,
|
174 |
)
|
175 |
|
176 |
# wrap the prompt in an expander to display additional parameters
|
|
|
188 |
div[data-testid="stMarkdownContainer"] p:not(:last-of-type) { margin-bottom: 0 }
|
189 |
</style>
|
190 |
""")
|
191 |
+
filtered_parameters = {
|
192 |
+
k: v
|
193 |
+
for k, v in message["parameters"].items()
|
194 |
+
if k not in Config.TXT2IMG_HIDDEN_PARAMETERS
|
195 |
+
}
|
196 |
md = f"`model`: {message['model']}\n\n"
|
197 |
+
md += "\n\n".join([f"`{k}`: {v}" for k, v in filtered_parameters.items()])
|
198 |
st.markdown(md)
|
199 |
|
200 |
if role == "assistant":
|
|
|
231 |
col1, col2 = st.columns(2)
|
232 |
with col1:
|
233 |
if (
|
234 |
+
st.button("❌", help="Delete last generation", disabled=st.session_state.running)
|
235 |
and len(st.session_state.txt2img_messages) >= 2
|
236 |
):
|
237 |
st.session_state.txt2img_messages.pop()
|
|
|
239 |
st.rerun()
|
240 |
|
241 |
with col2:
|
242 |
+
if st.button("🗑️", help="Clear all generations", disabled=st.session_state.running):
|
243 |
st.session_state.txt2img_messages = []
|
244 |
st.session_state.txt2img_seed = 0
|
245 |
st.rerun()
|
|
|
249 |
# show the prompt and spinner while loading then update state and re-render
|
250 |
if prompt := st.chat_input(
|
251 |
"What do you want to see?",
|
252 |
+
on_submit=lambda: setattr(st.session_state, "running", True),
|
253 |
):
|
254 |
if "seed" in parameters and parameters["seed"] >= 0:
|
255 |
st.session_state.txt2img_seed = parameters["seed"]
|
|
|
270 |
parameters.update(preset["kwargs"])
|
271 |
api_key = getattr(st.session_state, f"api_key_{service.lower()}", None)
|
272 |
image = txt2img_generate(api_key, service, model, prompt, parameters)
|
273 |
+
st.session_state.running = False
|
274 |
|
275 |
model_name = PRESET_MODEL[model]["name"]
|
276 |
st.session_state.txt2img_messages.append(
|