adamelliotfields
commited on
Commit
β’
c659a88
1
Parent(s):
7f5047d
Refactor preset
Browse files- 0_π _Home.py +4 -3
- lib/__init__.py +3 -1
- lib/api.py +18 -28
- lib/config.py +37 -53
- lib/preset.py +93 -113
- pages/1_π¬_Text_Generation.py +32 -36
- pages/2_π¨_Text_to_Image.py +28 -40
0_π _Home.py
CHANGED
@@ -4,12 +4,12 @@ from lib import config
|
|
4 |
|
5 |
st.set_page_config(
|
6 |
page_title=config.title,
|
7 |
-
page_icon=config.
|
8 |
layout=config.layout,
|
9 |
)
|
10 |
|
11 |
# sidebar
|
12 |
-
st.logo(
|
13 |
|
14 |
# title
|
15 |
st.html("""
|
@@ -39,6 +39,7 @@ st.html("""
|
|
39 |
<h1>API Inference</h1>
|
40 |
<span class="pro-badge">PRO</span>
|
41 |
</div>
|
|
|
42 |
""")
|
43 |
|
44 |
st.markdown("## Tasks")
|
@@ -58,7 +59,7 @@ st.markdown("""
|
|
58 |
st.markdown("""
|
59 |
## Usage
|
60 |
|
61 |
-
Choose a task
|
62 |
|
63 |
I recommend [duplicating this space](https://huggingface.co/spaces/adamelliotfields/api-inference?duplicate=true) **privately** and persisting your keys as secrets. See [`README.md`](https://huggingface.co/spaces/adamelliotfields/api-inference/blob/main/README.md).
|
64 |
""")
|
|
|
4 |
|
5 |
st.set_page_config(
|
6 |
page_title=config.title,
|
7 |
+
page_icon=config.logo,
|
8 |
layout=config.layout,
|
9 |
)
|
10 |
|
11 |
# sidebar
|
12 |
+
st.logo(config.logo)
|
13 |
|
14 |
# title
|
15 |
st.html("""
|
|
|
39 |
<h1>API Inference</h1>
|
40 |
<span class="pro-badge">PRO</span>
|
41 |
</div>
|
42 |
+
<p>Explore popular AI endpoints in one place.</p>
|
43 |
""")
|
44 |
|
45 |
st.markdown("## Tasks")
|
|
|
59 |
st.markdown("""
|
60 |
## Usage
|
61 |
|
62 |
+
Choose a task. Select a service. Enter your API key (refresh browser to clear).
|
63 |
|
64 |
I recommend [duplicating this space](https://huggingface.co/spaces/adamelliotfields/api-inference?duplicate=true) **privately** and persisting your keys as secrets. See [`README.md`](https://huggingface.co/spaces/adamelliotfields/api-inference/blob/main/README.md).
|
65 |
""")
|
lib/__init__.py
CHANGED
@@ -1,8 +1,10 @@
|
|
1 |
from .api import txt2img_generate, txt2txt_generate
|
2 |
from .config import config
|
3 |
-
from .preset import preset
|
4 |
|
5 |
__all__ = [
|
|
|
|
|
6 |
"config",
|
7 |
"preset",
|
8 |
"txt2img_generate",
|
|
|
1 |
from .api import txt2img_generate, txt2txt_generate
|
2 |
from .config import config
|
3 |
+
from .preset import Txt2ImgPreset, Txt2TxtPreset, preset
|
4 |
|
5 |
__all__ = [
|
6 |
+
"Txt2ImgPreset",
|
7 |
+
"Txt2TxtPreset",
|
8 |
"config",
|
9 |
"preset",
|
10 |
"txt2img_generate",
|
lib/api.py
CHANGED
@@ -11,8 +11,8 @@ from .config import config
|
|
11 |
|
12 |
|
13 |
def txt2txt_generate(api_key, service, model, parameters, **kwargs):
|
14 |
-
base_url = config.
|
15 |
-
if service == "
|
16 |
base_url = f"{base_url}/{model}/v1"
|
17 |
client = OpenAI(api_key=api_key, base_url=base_url)
|
18 |
|
@@ -29,42 +29,32 @@ def txt2txt_generate(api_key, service, model, parameters, **kwargs):
|
|
29 |
|
30 |
def txt2img_generate(api_key, service, model, inputs, parameters, **kwargs):
|
31 |
headers = {}
|
32 |
-
|
|
|
|
|
33 |
headers["x-key"] = api_key
|
|
|
34 |
|
35 |
-
if service == "
|
36 |
headers["Authorization"] = f"Key {api_key}"
|
|
|
37 |
|
38 |
-
if service == "
|
39 |
headers["Authorization"] = f"Bearer {api_key}"
|
40 |
headers["X-Wait-For-Model"] = "true"
|
41 |
headers["X-Use-Cache"] = "false"
|
42 |
-
|
43 |
-
if service == "Together":
|
44 |
-
headers["Authorization"] = f"Bearer {api_key}"
|
45 |
-
|
46 |
-
json = {}
|
47 |
-
if service == "Black Forest Labs":
|
48 |
-
json = {**parameters, **kwargs}
|
49 |
-
json["prompt"] = inputs
|
50 |
-
|
51 |
-
if service == "Fal":
|
52 |
-
json = {**parameters, **kwargs}
|
53 |
-
json["prompt"] = inputs
|
54 |
-
|
55 |
-
if service == "Hugging Face":
|
56 |
json = {
|
57 |
"inputs": inputs,
|
58 |
"parameters": {**parameters, **kwargs},
|
59 |
}
|
60 |
|
61 |
-
if service == "
|
62 |
-
|
63 |
json["prompt"] = inputs
|
64 |
|
65 |
-
base_url = config.
|
66 |
|
67 |
-
if service not in ["
|
68 |
base_url = f"{base_url}/{model}"
|
69 |
|
70 |
try:
|
@@ -72,9 +62,9 @@ def txt2img_generate(api_key, service, model, inputs, parameters, **kwargs):
|
|
72 |
if response.status_code // 100 == 2: # 2xx
|
73 |
# BFL is async so we need to poll for result
|
74 |
# https://api.bfl.ml/docs
|
75 |
-
if service == "
|
76 |
id = response.json()["id"]
|
77 |
-
url = f"{config.
|
78 |
|
79 |
retries = 0
|
80 |
while retries < config.txt2img.timeout:
|
@@ -95,7 +85,7 @@ def txt2img_generate(api_key, service, model, inputs, parameters, **kwargs):
|
|
95 |
|
96 |
return "Error: API timeout"
|
97 |
|
98 |
-
if service == "
|
99 |
# Sync mode means wait for image base64 string instead of CDN link
|
100 |
if parameters.get("sync_mode", True):
|
101 |
bytes = base64.b64decode(response.json()["images"][0]["url"].split(",")[-1])
|
@@ -105,10 +95,10 @@ def txt2img_generate(api_key, service, model, inputs, parameters, **kwargs):
|
|
105 |
image = httpx.get(url, headers=headers, timeout=config.txt2img.timeout)
|
106 |
return Image.open(io.BytesIO(image.content))
|
107 |
|
108 |
-
if service == "
|
109 |
return Image.open(io.BytesIO(response.content))
|
110 |
|
111 |
-
if service == "
|
112 |
url = response.json()["data"][0]["url"]
|
113 |
image = httpx.get(url, headers=headers, timeout=config.txt2img.timeout)
|
114 |
return Image.open(io.BytesIO(image.content))
|
|
|
11 |
|
12 |
|
13 |
def txt2txt_generate(api_key, service, model, parameters, **kwargs):
|
14 |
+
base_url = config.service[service].url
|
15 |
+
if service == "hf":
|
16 |
base_url = f"{base_url}/{model}/v1"
|
17 |
client = OpenAI(api_key=api_key, base_url=base_url)
|
18 |
|
|
|
29 |
|
30 |
def txt2img_generate(api_key, service, model, inputs, parameters, **kwargs):
|
31 |
headers = {}
|
32 |
+
json = {**parameters, **kwargs}
|
33 |
+
|
34 |
+
if service == "bfl":
|
35 |
headers["x-key"] = api_key
|
36 |
+
json["prompt"] = inputs
|
37 |
|
38 |
+
if service == "fal":
|
39 |
headers["Authorization"] = f"Key {api_key}"
|
40 |
+
json["prompt"] = inputs
|
41 |
|
42 |
+
if service == "hf":
|
43 |
headers["Authorization"] = f"Bearer {api_key}"
|
44 |
headers["X-Wait-For-Model"] = "true"
|
45 |
headers["X-Use-Cache"] = "false"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
46 |
json = {
|
47 |
"inputs": inputs,
|
48 |
"parameters": {**parameters, **kwargs},
|
49 |
}
|
50 |
|
51 |
+
if service == "together":
|
52 |
+
headers["Authorization"] = f"Bearer {api_key}"
|
53 |
json["prompt"] = inputs
|
54 |
|
55 |
+
base_url = config.service[service].url
|
56 |
|
57 |
+
if service not in ["together"]:
|
58 |
base_url = f"{base_url}/{model}"
|
59 |
|
60 |
try:
|
|
|
62 |
if response.status_code // 100 == 2: # 2xx
|
63 |
# BFL is async so we need to poll for result
|
64 |
# https://api.bfl.ml/docs
|
65 |
+
if service == "bfl":
|
66 |
id = response.json()["id"]
|
67 |
+
url = f"{config.service[service].url}/get_result?id={id}"
|
68 |
|
69 |
retries = 0
|
70 |
while retries < config.txt2img.timeout:
|
|
|
85 |
|
86 |
return "Error: API timeout"
|
87 |
|
88 |
+
if service == "fal":
|
89 |
# Sync mode means wait for image base64 string instead of CDN link
|
90 |
if parameters.get("sync_mode", True):
|
91 |
bytes = base64.b64decode(response.json()["images"][0]["url"].split(",")[-1])
|
|
|
95 |
image = httpx.get(url, headers=headers, timeout=config.txt2img.timeout)
|
96 |
return Image.open(io.BytesIO(image.content))
|
97 |
|
98 |
+
if service == "hf":
|
99 |
return Image.open(io.BytesIO(response.content))
|
100 |
|
101 |
+
if service == "together":
|
102 |
url = response.json()["data"][0]["url"]
|
103 |
image = httpx.get(url, headers=headers, timeout=config.txt2img.timeout)
|
104 |
return Image.open(io.BytesIO(image.content))
|
lib/config.py
CHANGED
@@ -1,31 +1,22 @@
|
|
|
|
1 |
from dataclasses import dataclass
|
2 |
-
from typing import Dict, List
|
3 |
|
4 |
-
from .preset import preset
|
5 |
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
model_id = p.model_id
|
12 |
-
if service not in models:
|
13 |
-
models[service] = []
|
14 |
-
models[service].append(model_id)
|
15 |
-
return models
|
16 |
|
17 |
|
18 |
@dataclass
|
19 |
class Txt2TxtConfig:
|
20 |
default_system: str
|
21 |
-
default_model: Dict[str, int]
|
22 |
-
models: Dict[str, List[str]]
|
23 |
|
24 |
|
25 |
@dataclass
|
26 |
class Txt2ImgConfig:
|
27 |
-
default_model: Dict[str, int]
|
28 |
-
models: Dict[str, List[str]]
|
29 |
hidden_parameters: List[str]
|
30 |
negative_prompt: str
|
31 |
default_image_size: str
|
@@ -38,37 +29,50 @@ class Txt2ImgConfig:
|
|
38 |
@dataclass
|
39 |
class Config:
|
40 |
title: str
|
41 |
-
icon: str
|
42 |
layout: str
|
43 |
-
|
|
|
44 |
txt2img: Txt2ImgConfig
|
45 |
txt2txt: Txt2TxtConfig
|
46 |
|
47 |
|
48 |
-
# TODO: API keys should be with services (make a dataclass)
|
49 |
config = Config(
|
50 |
title="API Inference",
|
51 |
-
icon="β‘",
|
52 |
layout="wide",
|
53 |
-
|
54 |
-
|
55 |
-
"
|
56 |
-
|
57 |
-
|
58 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
59 |
},
|
60 |
txt2img=Txt2ImgConfig(
|
61 |
-
default_model={
|
62 |
-
"Black Forest Labs": 2,
|
63 |
-
"Fal": 0,
|
64 |
-
"Hugging Face": 2,
|
65 |
-
"Together": 0,
|
66 |
-
},
|
67 |
-
models=txt2img_models_from_presets(preset.txt2img.presets),
|
68 |
hidden_parameters=[
|
69 |
# Sent to API but not shown in generation parameters accordion
|
70 |
"enable_safety_checker",
|
71 |
"max_sequence_length",
|
|
|
72 |
"num_images",
|
73 |
"output_format",
|
74 |
"performance",
|
@@ -115,25 +119,5 @@ config = Config(
|
|
115 |
),
|
116 |
txt2txt=Txt2TxtConfig(
|
117 |
default_system="You are a helpful assistant. Be precise and concise.",
|
118 |
-
default_model={
|
119 |
-
"Hugging Face": 4,
|
120 |
-
"Perplexity": 3,
|
121 |
-
},
|
122 |
-
models={
|
123 |
-
"Hugging Face": [
|
124 |
-
"codellama/codellama-34b-instruct-hf",
|
125 |
-
"meta-llama/llama-2-13b-chat-hf",
|
126 |
-
"meta-llama/meta-llama-3.1-405b-instruct-fp8",
|
127 |
-
"mistralai/mistral-7b-instruct-v0.2",
|
128 |
-
"nousresearch/nous-hermes-2-mixtral-8x7b-dpo",
|
129 |
-
],
|
130 |
-
"Perplexity": [
|
131 |
-
"llama-3.1-sonar-small-128k-chat",
|
132 |
-
"llama-3.1-sonar-large-128k-chat",
|
133 |
-
"llama-3.1-sonar-small-128k-online",
|
134 |
-
"llama-3.1-sonar-large-128k-online",
|
135 |
-
"llama-3.1-sonar-huge-128k-online",
|
136 |
-
],
|
137 |
-
},
|
138 |
),
|
139 |
)
|
|
|
1 |
+
import os
|
2 |
from dataclasses import dataclass
|
3 |
+
from typing import Dict, List, Optional
|
4 |
|
|
|
5 |
|
6 |
+
@dataclass
|
7 |
+
class ServiceConfig:
|
8 |
+
name: str
|
9 |
+
url: str
|
10 |
+
api_key: Optional[str] = None
|
|
|
|
|
|
|
|
|
|
|
11 |
|
12 |
|
13 |
@dataclass
|
14 |
class Txt2TxtConfig:
|
15 |
default_system: str
|
|
|
|
|
16 |
|
17 |
|
18 |
@dataclass
|
19 |
class Txt2ImgConfig:
|
|
|
|
|
20 |
hidden_parameters: List[str]
|
21 |
negative_prompt: str
|
22 |
default_image_size: str
|
|
|
29 |
@dataclass
|
30 |
class Config:
|
31 |
title: str
|
|
|
32 |
layout: str
|
33 |
+
logo: str
|
34 |
+
service: Dict[str, ServiceConfig]
|
35 |
txt2img: Txt2ImgConfig
|
36 |
txt2txt: Txt2TxtConfig
|
37 |
|
38 |
|
|
|
39 |
config = Config(
|
40 |
title="API Inference",
|
|
|
41 |
layout="wide",
|
42 |
+
logo="logo.png",
|
43 |
+
service={
|
44 |
+
"bfl": ServiceConfig(
|
45 |
+
"Black Forest Labs",
|
46 |
+
"https://api.bfl.ml/v1",
|
47 |
+
os.environ.get("BFL_API_KEY"),
|
48 |
+
),
|
49 |
+
"fal": ServiceConfig(
|
50 |
+
"Fal",
|
51 |
+
"https://fal.run",
|
52 |
+
os.environ.get("FAL_KEY"),
|
53 |
+
),
|
54 |
+
"hf": ServiceConfig(
|
55 |
+
"Hugging Face",
|
56 |
+
"https://api-inference.huggingface.co/models",
|
57 |
+
os.environ.get("HF_TOKEN"),
|
58 |
+
),
|
59 |
+
"pplx": ServiceConfig(
|
60 |
+
"Perplexity",
|
61 |
+
"https://api.perplexity.ai",
|
62 |
+
os.environ.get("PPLX_API_KEY"),
|
63 |
+
),
|
64 |
+
"together": ServiceConfig(
|
65 |
+
"Together",
|
66 |
+
"https://api.together.xyz/v1/images/generations",
|
67 |
+
os.environ.get("TOGETHER_API_KEY"),
|
68 |
+
),
|
69 |
},
|
70 |
txt2img=Txt2ImgConfig(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
71 |
hidden_parameters=[
|
72 |
# Sent to API but not shown in generation parameters accordion
|
73 |
"enable_safety_checker",
|
74 |
"max_sequence_length",
|
75 |
+
"n",
|
76 |
"num_images",
|
77 |
"output_format",
|
78 |
"performance",
|
|
|
119 |
),
|
120 |
txt2txt=Txt2TxtConfig(
|
121 |
default_system="You are a helpful assistant. Be precise and concise.",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
122 |
),
|
123 |
)
|
lib/preset.py
CHANGED
@@ -4,6 +4,7 @@ from typing import Dict, List, Optional, Union
|
|
4 |
|
5 |
@dataclass
|
6 |
class Txt2TxtPreset:
|
|
|
7 |
frequency_penalty: float
|
8 |
frequency_penalty_min: float
|
9 |
frequency_penalty_max: float
|
@@ -12,10 +13,7 @@ class Txt2TxtPreset:
|
|
12 |
|
13 |
@dataclass
|
14 |
class Txt2ImgPreset:
|
15 |
-
# FLUX1.1 has no scale or steps
|
16 |
name: str
|
17 |
-
service: str
|
18 |
-
model_id: str
|
19 |
guidance_scale: Optional[float] = None
|
20 |
guidance_scale_min: Optional[float] = None
|
21 |
guidance_scale_max: Optional[float] = None
|
@@ -27,66 +25,55 @@ class Txt2ImgPreset:
|
|
27 |
|
28 |
|
29 |
@dataclass
|
30 |
-
class
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
@dataclass
|
36 |
-
class Txt2ImgPresets:
|
37 |
-
presets: List[Txt2ImgPreset] = field(default_factory=list)
|
38 |
|
39 |
-
def __iter__(self):
|
40 |
-
return iter(self.presets)
|
41 |
|
|
|
|
|
|
|
|
|
|
|
|
|
42 |
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
|
|
|
|
47 |
|
48 |
|
49 |
preset = Preset(
|
50 |
-
txt2txt=
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
frequency_penalty_min=1.0,
|
60 |
-
frequency_penalty_max=2.0,
|
61 |
-
parameters=["max_tokens", "temperature", "frequency_penalty"],
|
62 |
-
),
|
63 |
-
),
|
64 |
-
txt2img=Txt2ImgPresets(
|
65 |
-
presets=[
|
66 |
-
Txt2ImgPreset(
|
67 |
-
"AuraFlow",
|
68 |
-
"Fal",
|
69 |
-
"fal-ai/aura-flow",
|
70 |
-
guidance_scale=3.5,
|
71 |
-
guidance_scale_min=1.0,
|
72 |
-
guidance_scale_max=10.0,
|
73 |
-
num_inference_steps=28,
|
74 |
-
num_inference_steps_min=10,
|
75 |
-
num_inference_steps_max=50,
|
76 |
-
parameters=["seed", "num_inference_steps", "guidance_scale", "expand_prompt"],
|
77 |
-
kwargs={"num_images": 1, "sync_mode": False},
|
78 |
),
|
79 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
80 |
"FLUX1.1 Pro",
|
81 |
-
"Black Forest Labs",
|
82 |
-
"flux-pro-1.1",
|
83 |
parameters=["seed", "width", "height", "prompt_upsampling"],
|
84 |
kwargs={"safety_tolerance": 6},
|
85 |
),
|
86 |
-
Txt2ImgPreset(
|
87 |
"FLUX.1 Pro",
|
88 |
-
"Black Forest Labs",
|
89 |
-
"flux-pro",
|
90 |
guidance_scale=2.5,
|
91 |
guidance_scale_min=1.5,
|
92 |
guidance_scale_max=5.0,
|
@@ -96,10 +83,8 @@ preset = Preset(
|
|
96 |
parameters=["seed", "width", "height", "steps", "guidance", "prompt_upsampling"],
|
97 |
kwargs={"safety_tolerance": 6, "interval": 1},
|
98 |
),
|
99 |
-
Txt2ImgPreset(
|
100 |
"FLUX.1 Dev",
|
101 |
-
"Black Forest Labs",
|
102 |
-
"flux-dev",
|
103 |
num_inference_steps=28,
|
104 |
num_inference_steps_min=10,
|
105 |
num_inference_steps_max=50,
|
@@ -109,10 +94,21 @@ preset = Preset(
|
|
109 |
parameters=["seed", "width", "height", "steps", "guidance", "prompt_upsampling"],
|
110 |
kwargs={"safety_tolerance": 6},
|
111 |
),
|
112 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
113 |
"FLUX1.1 Pro",
|
114 |
-
"Fal",
|
115 |
-
"fal-ai/flux-pro/v1.1",
|
116 |
parameters=["seed", "image_size"],
|
117 |
kwargs={
|
118 |
"num_images": 1,
|
@@ -121,10 +117,8 @@ preset = Preset(
|
|
121 |
"enable_safety_checker": False,
|
122 |
},
|
123 |
),
|
124 |
-
Txt2ImgPreset(
|
125 |
"FLUX.1 Pro",
|
126 |
-
"Fal",
|
127 |
-
"fal-ai/flux-pro",
|
128 |
guidance_scale=2.5,
|
129 |
guidance_scale_min=1.5,
|
130 |
guidance_scale_max=5.0,
|
@@ -134,10 +128,8 @@ preset = Preset(
|
|
134 |
parameters=["seed", "image_size", "num_inference_steps", "guidance_scale"],
|
135 |
kwargs={"num_images": 1, "sync_mode": False, "safety_tolerance": 6},
|
136 |
),
|
137 |
-
Txt2ImgPreset(
|
138 |
"FLUX.1 Dev",
|
139 |
-
"Fal",
|
140 |
-
"fal-ai/flux/dev",
|
141 |
num_inference_steps=28,
|
142 |
num_inference_steps_min=10,
|
143 |
num_inference_steps_max=50,
|
@@ -147,53 +139,16 @@ preset = Preset(
|
|
147 |
parameters=["seed", "image_size", "num_inference_steps", "guidance_scale"],
|
148 |
kwargs={"num_images": 1, "sync_mode": False, "safety_tolerance": 6},
|
149 |
),
|
150 |
-
Txt2ImgPreset(
|
151 |
"FLUX.1 Schnell",
|
152 |
-
"Fal",
|
153 |
-
"fal-ai/flux/schnell",
|
154 |
num_inference_steps=4,
|
155 |
num_inference_steps_min=1,
|
156 |
num_inference_steps_max=12,
|
157 |
parameters=["seed", "image_size", "num_inference_steps"],
|
158 |
kwargs={"num_images": 1, "sync_mode": False, "enable_safety_checker": False},
|
159 |
),
|
160 |
-
Txt2ImgPreset(
|
161 |
-
"FLUX.1 Dev",
|
162 |
-
"Hugging Face",
|
163 |
-
"black-forest-labs/flux.1-dev",
|
164 |
-
num_inference_steps=28,
|
165 |
-
num_inference_steps_min=10,
|
166 |
-
num_inference_steps_max=50,
|
167 |
-
guidance_scale=3.0,
|
168 |
-
guidance_scale_min=1.5,
|
169 |
-
guidance_scale_max=5.0,
|
170 |
-
parameters=["width", "height", "guidance_scale", "num_inference_steps"],
|
171 |
-
kwargs={"max_sequence_length": 512},
|
172 |
-
),
|
173 |
-
Txt2ImgPreset(
|
174 |
-
"FLUX.1 Schnell",
|
175 |
-
"Hugging Face",
|
176 |
-
"black-forest-labs/flux.1-schnell",
|
177 |
-
num_inference_steps=4,
|
178 |
-
num_inference_steps_min=1,
|
179 |
-
num_inference_steps_max=12,
|
180 |
-
parameters=["width", "height", "num_inference_steps"],
|
181 |
-
kwargs={"guidance_scale": 0.0, "max_sequence_length": 256},
|
182 |
-
),
|
183 |
-
Txt2ImgPreset(
|
184 |
-
"FLUX.1 Schnell Free",
|
185 |
-
"Together",
|
186 |
-
"black-forest-labs/FLUX.1-schnell-Free",
|
187 |
-
num_inference_steps=4,
|
188 |
-
num_inference_steps_min=1,
|
189 |
-
num_inference_steps_max=12,
|
190 |
-
parameters=["model", "seed", "width", "height", "steps"],
|
191 |
-
kwargs={"n": 1},
|
192 |
-
),
|
193 |
-
Txt2ImgPreset(
|
194 |
"Fooocus",
|
195 |
-
"Fal",
|
196 |
-
"fal-ai/fooocus",
|
197 |
guidance_scale=4.0,
|
198 |
guidance_scale_min=1.0,
|
199 |
guidance_scale_max=10.0,
|
@@ -208,10 +163,8 @@ preset = Preset(
|
|
208 |
"performance": "Quality",
|
209 |
},
|
210 |
),
|
211 |
-
Txt2ImgPreset(
|
212 |
"Kolors",
|
213 |
-
"Fal",
|
214 |
-
"fal-ai/kolors",
|
215 |
guidance_scale=5.0,
|
216 |
guidance_scale_min=1.0,
|
217 |
guidance_scale_max=10.0,
|
@@ -226,10 +179,8 @@ preset = Preset(
|
|
226 |
"scheduler": "EulerDiscreteScheduler",
|
227 |
},
|
228 |
),
|
229 |
-
Txt2ImgPreset(
|
230 |
"SD3",
|
231 |
-
"Fal",
|
232 |
-
"fal-ai/stable-diffusion-v3-medium",
|
233 |
guidance_scale=5.0,
|
234 |
guidance_scale_min=1.0,
|
235 |
guidance_scale_max=10.0,
|
@@ -246,10 +197,29 @@ preset = Preset(
|
|
246 |
],
|
247 |
kwargs={"num_images": 1, "sync_mode": True, "enable_safety_checker": False},
|
248 |
),
|
249 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
250 |
"SDXL",
|
251 |
-
"Hugging Face",
|
252 |
-
"stabilityai/stable-diffusion-xl-base-1.0",
|
253 |
guidance_scale=7.0,
|
254 |
guidance_scale_min=1.0,
|
255 |
guidance_scale_max=10.0,
|
@@ -265,6 +235,16 @@ preset = Preset(
|
|
265 |
"num_inference_steps",
|
266 |
],
|
267 |
),
|
268 |
-
|
269 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
270 |
)
|
|
|
4 |
|
5 |
@dataclass
|
6 |
class Txt2TxtPreset:
|
7 |
+
name: str
|
8 |
frequency_penalty: float
|
9 |
frequency_penalty_min: float
|
10 |
frequency_penalty_max: float
|
|
|
13 |
|
14 |
@dataclass
|
15 |
class Txt2ImgPreset:
|
|
|
16 |
name: str
|
|
|
|
|
17 |
guidance_scale: Optional[float] = None
|
18 |
guidance_scale_min: Optional[float] = None
|
19 |
guidance_scale_max: Optional[float] = None
|
|
|
25 |
|
26 |
|
27 |
@dataclass
|
28 |
+
class Preset:
|
29 |
+
txt2txt: Dict[str, Txt2TxtPreset]
|
30 |
+
txt2img: Dict[str, Txt2ImgPreset]
|
|
|
|
|
|
|
|
|
|
|
31 |
|
|
|
|
|
32 |
|
33 |
+
hf_txt2txt_kwargs = {
|
34 |
+
"frequency_penalty": 0.0,
|
35 |
+
"frequency_penalty_min": -2.0,
|
36 |
+
"frequency_penalty_max": 2.0,
|
37 |
+
"parameters": ["max_tokens", "temperature", "frequency_penalty", "seed"],
|
38 |
+
}
|
39 |
|
40 |
+
pplx_txt2txt_kwargs = {
|
41 |
+
"frequency_penalty": 1.0,
|
42 |
+
"frequency_penalty_min": 1.0,
|
43 |
+
"frequency_penalty_max": 2.0,
|
44 |
+
"parameters": ["max_tokens", "temperature", "frequency_penalty"],
|
45 |
+
}
|
46 |
|
47 |
|
48 |
preset = Preset(
|
49 |
+
txt2txt={
|
50 |
+
"hf": {
|
51 |
+
# TODO: update models
|
52 |
+
"codellama/codellama-34b-instruct-hf": Txt2TxtPreset("Code Llama 34B", **hf_txt2txt_kwargs),
|
53 |
+
"meta-llama/llama-2-13b-chat-hf": Txt2TxtPreset("Llama 2 13B", **hf_txt2txt_kwargs),
|
54 |
+
"mistralai/mistral-7b-instruct-v0.2": Txt2TxtPreset("Mistral v0.2 7B", **hf_txt2txt_kwargs),
|
55 |
+
"nousresearch/nous-hermes-2-mixtral-8x7b-dpo": Txt2TxtPreset(
|
56 |
+
"Nous Hermes 2 Mixtral 8x7B",
|
57 |
+
**hf_txt2txt_kwargs,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
58 |
),
|
59 |
+
},
|
60 |
+
"pplx": {
|
61 |
+
"llama-3.1-sonar-small-128k-chat": Txt2TxtPreset("Sonar Small (Offline)", **pplx_txt2txt_kwargs),
|
62 |
+
"llama-3.1-sonar-large-128k-chat": Txt2TxtPreset("Sonar Large (Offline)", **pplx_txt2txt_kwargs),
|
63 |
+
"llama-3.1-sonar-small-128k-online": Txt2TxtPreset("Sonar Small (Online)", **pplx_txt2txt_kwargs),
|
64 |
+
"llama-3.1-sonar-large-128k-online": Txt2TxtPreset("Sonar Large (Online)", **pplx_txt2txt_kwargs),
|
65 |
+
"llama-3.1-sonar-huge-128k-online": Txt2TxtPreset("Sonar Huge (Online)", **pplx_txt2txt_kwargs),
|
66 |
+
},
|
67 |
+
},
|
68 |
+
txt2img={
|
69 |
+
"bfl": {
|
70 |
+
"flux-pro-1.1": Txt2ImgPreset(
|
71 |
"FLUX1.1 Pro",
|
|
|
|
|
72 |
parameters=["seed", "width", "height", "prompt_upsampling"],
|
73 |
kwargs={"safety_tolerance": 6},
|
74 |
),
|
75 |
+
"flux-pro": Txt2ImgPreset(
|
76 |
"FLUX.1 Pro",
|
|
|
|
|
77 |
guidance_scale=2.5,
|
78 |
guidance_scale_min=1.5,
|
79 |
guidance_scale_max=5.0,
|
|
|
83 |
parameters=["seed", "width", "height", "steps", "guidance", "prompt_upsampling"],
|
84 |
kwargs={"safety_tolerance": 6, "interval": 1},
|
85 |
),
|
86 |
+
"flux-dev": Txt2ImgPreset(
|
87 |
"FLUX.1 Dev",
|
|
|
|
|
88 |
num_inference_steps=28,
|
89 |
num_inference_steps_min=10,
|
90 |
num_inference_steps_max=50,
|
|
|
94 |
parameters=["seed", "width", "height", "steps", "guidance", "prompt_upsampling"],
|
95 |
kwargs={"safety_tolerance": 6},
|
96 |
),
|
97 |
+
},
|
98 |
+
"fal": {
|
99 |
+
"fal-ai/aura-flow": Txt2ImgPreset(
|
100 |
+
"AuraFlow",
|
101 |
+
guidance_scale=3.5,
|
102 |
+
guidance_scale_min=1.0,
|
103 |
+
guidance_scale_max=10.0,
|
104 |
+
num_inference_steps=28,
|
105 |
+
num_inference_steps_min=10,
|
106 |
+
num_inference_steps_max=50,
|
107 |
+
parameters=["seed", "num_inference_steps", "guidance_scale", "expand_prompt"],
|
108 |
+
kwargs={"num_images": 1, "sync_mode": False},
|
109 |
+
),
|
110 |
+
"fal-ai/flux-pro/v1.1": Txt2ImgPreset(
|
111 |
"FLUX1.1 Pro",
|
|
|
|
|
112 |
parameters=["seed", "image_size"],
|
113 |
kwargs={
|
114 |
"num_images": 1,
|
|
|
117 |
"enable_safety_checker": False,
|
118 |
},
|
119 |
),
|
120 |
+
"fal-ai/flux-pro": Txt2ImgPreset(
|
121 |
"FLUX.1 Pro",
|
|
|
|
|
122 |
guidance_scale=2.5,
|
123 |
guidance_scale_min=1.5,
|
124 |
guidance_scale_max=5.0,
|
|
|
128 |
parameters=["seed", "image_size", "num_inference_steps", "guidance_scale"],
|
129 |
kwargs={"num_images": 1, "sync_mode": False, "safety_tolerance": 6},
|
130 |
),
|
131 |
+
"fal-ai/flux/dev": Txt2ImgPreset(
|
132 |
"FLUX.1 Dev",
|
|
|
|
|
133 |
num_inference_steps=28,
|
134 |
num_inference_steps_min=10,
|
135 |
num_inference_steps_max=50,
|
|
|
139 |
parameters=["seed", "image_size", "num_inference_steps", "guidance_scale"],
|
140 |
kwargs={"num_images": 1, "sync_mode": False, "safety_tolerance": 6},
|
141 |
),
|
142 |
+
"fal-ai/flux/schnell": Txt2ImgPreset(
|
143 |
"FLUX.1 Schnell",
|
|
|
|
|
144 |
num_inference_steps=4,
|
145 |
num_inference_steps_min=1,
|
146 |
num_inference_steps_max=12,
|
147 |
parameters=["seed", "image_size", "num_inference_steps"],
|
148 |
kwargs={"num_images": 1, "sync_mode": False, "enable_safety_checker": False},
|
149 |
),
|
150 |
+
"fal-ai/fooocus": Txt2ImgPreset(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
151 |
"Fooocus",
|
|
|
|
|
152 |
guidance_scale=4.0,
|
153 |
guidance_scale_min=1.0,
|
154 |
guidance_scale_max=10.0,
|
|
|
163 |
"performance": "Quality",
|
164 |
},
|
165 |
),
|
166 |
+
"fal-ai/kolors": Txt2ImgPreset(
|
167 |
"Kolors",
|
|
|
|
|
168 |
guidance_scale=5.0,
|
169 |
guidance_scale_min=1.0,
|
170 |
guidance_scale_max=10.0,
|
|
|
179 |
"scheduler": "EulerDiscreteScheduler",
|
180 |
},
|
181 |
),
|
182 |
+
"fal-ai/stable-diffusion-v3-medium": Txt2ImgPreset(
|
183 |
"SD3",
|
|
|
|
|
184 |
guidance_scale=5.0,
|
185 |
guidance_scale_min=1.0,
|
186 |
guidance_scale_max=10.0,
|
|
|
197 |
],
|
198 |
kwargs={"num_images": 1, "sync_mode": True, "enable_safety_checker": False},
|
199 |
),
|
200 |
+
},
|
201 |
+
"hf": {
|
202 |
+
"black-forest-labs/flux.1-dev": Txt2ImgPreset(
|
203 |
+
"FLUX.1 Dev",
|
204 |
+
num_inference_steps=28,
|
205 |
+
num_inference_steps_min=10,
|
206 |
+
num_inference_steps_max=50,
|
207 |
+
guidance_scale=3.0,
|
208 |
+
guidance_scale_min=1.5,
|
209 |
+
guidance_scale_max=5.0,
|
210 |
+
parameters=["width", "height", "guidance_scale", "num_inference_steps"],
|
211 |
+
kwargs={"max_sequence_length": 512},
|
212 |
+
),
|
213 |
+
"black-forest-labs/flux.1-schnell": Txt2ImgPreset(
|
214 |
+
"FLUX.1 Schnell",
|
215 |
+
num_inference_steps=4,
|
216 |
+
num_inference_steps_min=1,
|
217 |
+
num_inference_steps_max=12,
|
218 |
+
parameters=["width", "height", "num_inference_steps"],
|
219 |
+
kwargs={"guidance_scale": 0.0, "max_sequence_length": 256},
|
220 |
+
),
|
221 |
+
"stabilityai/stable-diffusion-xl-base-1.0": Txt2ImgPreset(
|
222 |
"SDXL",
|
|
|
|
|
223 |
guidance_scale=7.0,
|
224 |
guidance_scale_min=1.0,
|
225 |
guidance_scale_max=10.0,
|
|
|
235 |
"num_inference_steps",
|
236 |
],
|
237 |
),
|
238 |
+
},
|
239 |
+
"together": {
|
240 |
+
"black-forest-labs/FLUX.1-schnell-Free": Txt2ImgPreset(
|
241 |
+
"FLUX.1 Schnell Free",
|
242 |
+
num_inference_steps=4,
|
243 |
+
num_inference_steps_min=1,
|
244 |
+
num_inference_steps_max=12,
|
245 |
+
parameters=["model", "seed", "width", "height", "steps"],
|
246 |
+
kwargs={"n": 1},
|
247 |
+
),
|
248 |
+
},
|
249 |
+
},
|
250 |
)
|
pages/1_π¬_Text_Generation.py
CHANGED
@@ -1,33 +1,23 @@
|
|
1 |
-
import os
|
2 |
from datetime import datetime
|
|
|
3 |
|
4 |
import streamlit as st
|
5 |
|
6 |
-
from lib import config, preset, txt2txt_generate
|
7 |
-
|
8 |
-
SERVICE_SESSION = {
|
9 |
-
"Hugging Face": "api_key_hugging_face",
|
10 |
-
"Perplexity": "api_key_perplexity",
|
11 |
-
}
|
12 |
-
|
13 |
-
SESSION_TOKEN = {
|
14 |
-
"api_key_hugging_face": os.environ.get("HF_TOKEN") or None,
|
15 |
-
"api_key_perplexity": os.environ.get("PPLX_API_KEY") or None,
|
16 |
-
}
|
17 |
|
18 |
# config
|
19 |
st.set_page_config(
|
20 |
page_title=f"{config.title} | Text Generation",
|
21 |
-
page_icon=config.
|
22 |
layout=config.layout,
|
23 |
)
|
24 |
|
25 |
# initialize state
|
26 |
-
if "
|
27 |
-
st.session_state.
|
28 |
|
29 |
-
if "
|
30 |
-
st.session_state.
|
31 |
|
32 |
if "running" not in st.session_state:
|
33 |
st.session_state.running = False
|
@@ -39,33 +29,41 @@ if "txt2txt_seed" not in st.session_state:
|
|
39 |
st.session_state.txt2txt_seed = 0
|
40 |
|
41 |
# sidebar
|
42 |
-
st.logo(
|
43 |
st.sidebar.header("Settings")
|
|
|
44 |
service = st.sidebar.selectbox(
|
45 |
"Service",
|
46 |
-
options=
|
47 |
-
|
48 |
disabled=st.session_state.running,
|
49 |
)
|
50 |
|
51 |
# disable API key input and hide value if set by environment variable (handle empty string value later)
|
52 |
-
for display_name, session_key in SERVICE_SESSION.items():
|
53 |
-
|
|
|
|
|
|
|
54 |
st.session_state[session_key] = st.sidebar.text_input(
|
55 |
"API Key",
|
56 |
type="password",
|
57 |
-
value="" if
|
58 |
-
disabled=bool(
|
59 |
-
help="Set by environment variable" if
|
60 |
)
|
61 |
|
|
|
|
|
62 |
model = st.sidebar.selectbox(
|
63 |
"Model",
|
64 |
-
options=
|
65 |
-
|
66 |
disabled=st.session_state.running,
|
67 |
-
format_func=lambda x: x.split("/")[1] if service == "Hugging Face" else x,
|
68 |
)
|
|
|
|
|
|
|
69 |
system = st.sidebar.text_area(
|
70 |
"System Message",
|
71 |
value=config.txt2txt.default_system,
|
@@ -74,9 +72,7 @@ system = st.sidebar.text_area(
|
|
74 |
|
75 |
# build parameters from preset
|
76 |
parameters = {}
|
77 |
-
|
78 |
-
service_preset = getattr(preset.txt2txt, service_key)
|
79 |
-
for param in service_preset.parameters:
|
80 |
if param == "max_tokens":
|
81 |
parameters[param] = st.sidebar.slider(
|
82 |
"Max Tokens",
|
@@ -101,9 +97,9 @@ for param in service_preset.parameters:
|
|
101 |
parameters[param] = st.sidebar.slider(
|
102 |
"Frequency Penalty",
|
103 |
step=0.1,
|
104 |
-
value=
|
105 |
-
min_value=
|
106 |
-
max_value=
|
107 |
disabled=st.session_state.running,
|
108 |
help="Penalize new tokens based on their existing frequency in the text (default: 0.0)",
|
109 |
)
|
@@ -180,8 +176,8 @@ if prompt := st.chat_input(
|
|
180 |
st.markdown(prompt)
|
181 |
|
182 |
with st.chat_message("assistant"):
|
183 |
-
session_key = f"api_key_{service
|
184 |
-
api_key = st.session_state[session_key] or
|
185 |
response = txt2txt_generate(api_key, service, model, parameters)
|
186 |
st.session_state.running = False
|
187 |
|
|
|
|
|
1 |
from datetime import datetime
|
2 |
+
from typing import Dict
|
3 |
|
4 |
import streamlit as st
|
5 |
|
6 |
+
from lib import Txt2TxtPreset, config, preset, txt2txt_generate
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
|
8 |
# config
|
9 |
st.set_page_config(
|
10 |
page_title=f"{config.title} | Text Generation",
|
11 |
+
page_icon=config.logo,
|
12 |
layout=config.layout,
|
13 |
)
|
14 |
|
15 |
# initialize state
|
16 |
+
if "api_key_hf" not in st.session_state:
|
17 |
+
st.session_state.api_key_hf = ""
|
18 |
|
19 |
+
if "api_key_pplx" not in st.session_state:
|
20 |
+
st.session_state.api_key_pplx = ""
|
21 |
|
22 |
if "running" not in st.session_state:
|
23 |
st.session_state.running = False
|
|
|
29 |
st.session_state.txt2txt_seed = 0
|
30 |
|
31 |
# sidebar
|
32 |
+
st.logo(config.logo)
|
33 |
st.sidebar.header("Settings")
|
34 |
+
|
35 |
service = st.sidebar.selectbox(
|
36 |
"Service",
|
37 |
+
options=preset.txt2txt.keys(),
|
38 |
+
format_func=lambda x: config.service[x].name,
|
39 |
disabled=st.session_state.running,
|
40 |
)
|
41 |
|
42 |
# disable API key input and hide value if set by environment variable (handle empty string value later)
|
43 |
+
# for display_name, session_key in SERVICE_SESSION.items():
|
44 |
+
for service_id, service_config in config.service.items():
|
45 |
+
if service == service_id:
|
46 |
+
session_key = f"api_key_{service}"
|
47 |
+
api_key = config.service[service].api_key
|
48 |
st.session_state[session_key] = st.sidebar.text_input(
|
49 |
"API Key",
|
50 |
type="password",
|
51 |
+
value="" if api_key else st.session_state[session_key],
|
52 |
+
disabled=bool(api_key) or st.session_state.running,
|
53 |
+
help="Set by environment variable" if api_key else "Cleared on page refresh",
|
54 |
)
|
55 |
|
56 |
+
service_preset: Dict[str, Txt2TxtPreset] = preset.txt2txt[service]
|
57 |
+
|
58 |
model = st.sidebar.selectbox(
|
59 |
"Model",
|
60 |
+
options=service_preset.keys(),
|
61 |
+
format_func=lambda x: service_preset[x].name,
|
62 |
disabled=st.session_state.running,
|
|
|
63 |
)
|
64 |
+
|
65 |
+
model_preset = service_preset[model]
|
66 |
+
|
67 |
system = st.sidebar.text_area(
|
68 |
"System Message",
|
69 |
value=config.txt2txt.default_system,
|
|
|
72 |
|
73 |
# build parameters from preset
|
74 |
parameters = {}
|
75 |
+
for param in model_preset.parameters:
|
|
|
|
|
76 |
if param == "max_tokens":
|
77 |
parameters[param] = st.sidebar.slider(
|
78 |
"Max Tokens",
|
|
|
97 |
parameters[param] = st.sidebar.slider(
|
98 |
"Frequency Penalty",
|
99 |
step=0.1,
|
100 |
+
value=model_preset.frequency_penalty,
|
101 |
+
min_value=model_preset.frequency_penalty_min,
|
102 |
+
max_value=model_preset.frequency_penalty_max,
|
103 |
disabled=st.session_state.running,
|
104 |
help="Penalize new tokens based on their existing frequency in the text (default: 0.0)",
|
105 |
)
|
|
|
176 |
st.markdown(prompt)
|
177 |
|
178 |
with st.chat_message("assistant"):
|
179 |
+
session_key = f"api_key_{service}"
|
180 |
+
api_key = st.session_state[session_key] or config.service[service].api_key
|
181 |
response = txt2txt_generate(api_key, service, model, parameters)
|
182 |
st.session_state.running = False
|
183 |
|
pages/2_π¨_Text_to_Image.py
CHANGED
@@ -1,44 +1,25 @@
|
|
1 |
-
import os
|
2 |
from datetime import datetime
|
|
|
3 |
|
4 |
import streamlit as st
|
5 |
|
6 |
-
from lib import config, preset, txt2img_generate
|
7 |
-
|
8 |
-
# The token name is the service in lower_snake_case
|
9 |
-
SERVICE_SESSION = {
|
10 |
-
"Black Forest Labs": "api_key_black_forest_labs",
|
11 |
-
"Fal": "api_key_fal",
|
12 |
-
"Hugging Face": "api_key_hugging_face",
|
13 |
-
"Together": "api_key_together",
|
14 |
-
}
|
15 |
-
|
16 |
-
SESSION_TOKEN = {
|
17 |
-
"api_key_black_forest_labs": os.environ.get("BFL_API_KEY") or None,
|
18 |
-
"api_key_fal": os.environ.get("FAL_KEY") or None,
|
19 |
-
"api_key_hugging_face": os.environ.get("HF_TOKEN") or None,
|
20 |
-
"api_key_together": os.environ.get("TOGETHER_API_KEY") or None,
|
21 |
-
}
|
22 |
-
|
23 |
-
PRESET_MODEL = {}
|
24 |
-
for p in preset.txt2img.presets:
|
25 |
-
PRESET_MODEL[p.model_id] = p
|
26 |
|
27 |
st.set_page_config(
|
28 |
page_title=f"{config.title} | Text to Image",
|
29 |
-
page_icon=config.
|
30 |
layout=config.layout,
|
31 |
)
|
32 |
|
33 |
# Initialize Streamlit session state
|
34 |
-
if "
|
35 |
-
st.session_state.
|
36 |
|
37 |
if "api_key_fal" not in st.session_state:
|
38 |
st.session_state.api_key_fal = ""
|
39 |
|
40 |
-
if "
|
41 |
-
st.session_state.
|
42 |
|
43 |
if "api_key_together" not in st.session_state:
|
44 |
st.session_state.api_key_together = ""
|
@@ -52,34 +33,42 @@ if "txt2img_messages" not in st.session_state:
|
|
52 |
if "txt2img_seed" not in st.session_state:
|
53 |
st.session_state.txt2img_seed = 0
|
54 |
|
55 |
-
st.logo(
|
56 |
st.sidebar.header("Settings")
|
|
|
57 |
service = st.sidebar.selectbox(
|
58 |
"Service",
|
59 |
-
options=
|
|
|
60 |
disabled=st.session_state.running,
|
61 |
-
index=2, # Hugging Face
|
62 |
)
|
63 |
|
64 |
# Show the API key input for the selected service.
|
65 |
# Disable and hide value if set by environment variable; handle empty string value later.
|
66 |
-
for display_name, session_key in SERVICE_SESSION.items():
|
67 |
-
|
|
|
|
|
|
|
68 |
st.session_state[session_key] = st.sidebar.text_input(
|
69 |
"API Key",
|
70 |
type="password",
|
71 |
-
value="" if
|
72 |
-
disabled=bool(
|
73 |
-
help="Set by environment variable" if
|
74 |
)
|
75 |
|
|
|
|
|
76 |
model = st.sidebar.selectbox(
|
77 |
"Model",
|
78 |
-
options=
|
79 |
-
|
80 |
disabled=st.session_state.running,
|
81 |
)
|
82 |
|
|
|
|
|
83 |
# heading
|
84 |
st.html("""
|
85 |
<h1>Text to Image</h1>
|
@@ -88,7 +77,6 @@ st.html("""
|
|
88 |
|
89 |
# Build parameters from preset by rendering the appropriate input widgets
|
90 |
parameters = {}
|
91 |
-
model_preset = PRESET_MODEL[model]
|
92 |
for param in model_preset.parameters:
|
93 |
if param == "model":
|
94 |
parameters[param] = model
|
@@ -262,13 +250,13 @@ if prompt := st.chat_input(
|
|
262 |
with st.spinner("Running..."):
|
263 |
if model_preset.kwargs:
|
264 |
parameters.update(model_preset.kwargs)
|
265 |
-
session_key = f"api_key_{service
|
266 |
-
api_key = st.session_state[session_key] or
|
267 |
image = txt2img_generate(api_key, service, model, prompt, parameters)
|
268 |
st.session_state.running = False
|
269 |
|
270 |
st.session_state.txt2img_messages.append(
|
271 |
-
{"role": "user", "content": prompt, "parameters": parameters, "model":
|
272 |
)
|
273 |
st.session_state.txt2img_messages.append({"role": "assistant", "content": image})
|
274 |
st.rerun()
|
|
|
|
|
1 |
from datetime import datetime
|
2 |
+
from typing import Dict
|
3 |
|
4 |
import streamlit as st
|
5 |
|
6 |
+
from lib import Txt2ImgPreset, config, preset, txt2img_generate
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
|
8 |
st.set_page_config(
|
9 |
page_title=f"{config.title} | Text to Image",
|
10 |
+
page_icon=config.logo,
|
11 |
layout=config.layout,
|
12 |
)
|
13 |
|
14 |
# Initialize Streamlit session state
|
15 |
+
if "api_key_bfl" not in st.session_state:
|
16 |
+
st.session_state.api_key_bfl = ""
|
17 |
|
18 |
if "api_key_fal" not in st.session_state:
|
19 |
st.session_state.api_key_fal = ""
|
20 |
|
21 |
+
if "api_key_hf" not in st.session_state:
|
22 |
+
st.session_state.api_key_hf = ""
|
23 |
|
24 |
if "api_key_together" not in st.session_state:
|
25 |
st.session_state.api_key_together = ""
|
|
|
33 |
if "txt2img_seed" not in st.session_state:
|
34 |
st.session_state.txt2img_seed = 0
|
35 |
|
36 |
+
st.logo(config.logo)
|
37 |
st.sidebar.header("Settings")
|
38 |
+
|
39 |
service = st.sidebar.selectbox(
|
40 |
"Service",
|
41 |
+
options=preset.txt2img.keys(),
|
42 |
+
format_func=lambda x: config.service[x].name,
|
43 |
disabled=st.session_state.running,
|
|
|
44 |
)
|
45 |
|
46 |
# Show the API key input for the selected service.
|
47 |
# Disable and hide value if set by environment variable; handle empty string value later.
|
48 |
+
# for display_name, session_key in SERVICE_SESSION.items():
|
49 |
+
for service_id in config.service.keys():
|
50 |
+
if service == service_id:
|
51 |
+
session_key = f"api_key_{service}"
|
52 |
+
api_key = config.service[service].api_key
|
53 |
st.session_state[session_key] = st.sidebar.text_input(
|
54 |
"API Key",
|
55 |
type="password",
|
56 |
+
value="" if api_key else st.session_state[session_key],
|
57 |
+
disabled=bool(api_key) or st.session_state.running,
|
58 |
+
help="Set by environment variable" if api_key else "Cleared on page refresh",
|
59 |
)
|
60 |
|
61 |
+
service_preset: Dict[str, Txt2ImgPreset] = preset.txt2img[service]
|
62 |
+
|
63 |
model = st.sidebar.selectbox(
|
64 |
"Model",
|
65 |
+
options=service_preset.keys(),
|
66 |
+
format_func=lambda x: service_preset[x].name,
|
67 |
disabled=st.session_state.running,
|
68 |
)
|
69 |
|
70 |
+
model_preset = service_preset[model]
|
71 |
+
|
72 |
# heading
|
73 |
st.html("""
|
74 |
<h1>Text to Image</h1>
|
|
|
77 |
|
78 |
# Build parameters from preset by rendering the appropriate input widgets
|
79 |
parameters = {}
|
|
|
80 |
for param in model_preset.parameters:
|
81 |
if param == "model":
|
82 |
parameters[param] = model
|
|
|
250 |
with st.spinner("Running..."):
|
251 |
if model_preset.kwargs:
|
252 |
parameters.update(model_preset.kwargs)
|
253 |
+
session_key = f"api_key_{service}"
|
254 |
+
api_key = st.session_state[session_key] or config.service[service].api_key
|
255 |
image = txt2img_generate(api_key, service, model, prompt, parameters)
|
256 |
st.session_state.running = False
|
257 |
|
258 |
st.session_state.txt2img_messages.append(
|
259 |
+
{"role": "user", "content": prompt, "parameters": parameters, "model": model_preset.name}
|
260 |
)
|
261 |
st.session_state.txt2img_messages.append({"role": "assistant", "content": image})
|
262 |
st.rerun()
|