adamelliotfields committed

Move presets to config

Files changed:
- lib/__init__.py +0 -4
- lib/api.py +10 -9
- lib/config.py +381 -89
- lib/preset.py +0 -250
- pages/1_💬_Text_Generation.py +28 -24
- pages/2_🎨_Text_to_Image.py +38 -35
lib/__init__.py CHANGED

@@ -1,12 +1,8 @@
 from .api import txt2img_generate, txt2txt_generate
 from .config import config
-from .preset import Txt2ImgPreset, Txt2TxtPreset, preset
 
 __all__ = [
-    "Txt2ImgPreset",
-    "Txt2TxtPreset",
     "config",
-    "preset",
     "txt2img_generate",
     "txt2txt_generate",
]
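After this change the package exposes only the config object and the two generate helpers; a minimal downstream import (mirroring what the pages below do) looks like:

    from lib import config, txt2img_generate, txt2txt_generate

    print(config.title)  # "API Inference"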
lib/api.py CHANGED

@@ -11,7 +11,7 @@ from .config import config
 
 
 def txt2txt_generate(api_key, service, model, parameters, **kwargs):
-    base_url = config.
+    base_url = config.services[service].url
     if service == "hf":
         base_url = f"{base_url}/{model}/v1"
     client = OpenAI(api_key=api_key, base_url=base_url)
@@ -52,23 +52,24 @@ def txt2img_generate(api_key, service, model, inputs, parameters, **kwargs):
        headers["Authorization"] = f"Bearer {api_key}"
        json["prompt"] = inputs
 
-    base_url = config.
+    base_url = config.services[service].url
 
    if service not in ["together"]:
        base_url = f"{base_url}/{model}"
 
    try:
-
+        timeout = config.timeout
+        response = httpx.post(base_url, headers=headers, json=json, timeout=timeout)
        if response.status_code // 100 == 2:  # 2xx
            # BFL is async so we need to poll for result
            # https://api.bfl.ml/docs
            if service == "bfl":
                id = response.json()["id"]
-                url = f"{config.
+                url = f"{config.services[service].url}/get_result?id={id}"
 
                retries = 0
-                while retries <
-                    response = httpx.get(url, timeout=
+                while retries < timeout:
+                    response = httpx.get(url, timeout=timeout)
                    if response.status_code // 100 != 2:
                        return f"Error: {response.status_code} {response.text}"
 
@@ -76,7 +77,7 @@ def txt2img_generate(api_key, service, model, inputs, parameters, **kwargs):
                    image = httpx.get(
                        response.json()["result"]["sample"],
                        headers=headers,
-                        timeout=
+                        timeout=timeout,
                    )
                    return Image.open(io.BytesIO(image.content))
 
@@ -92,7 +93,7 @@ def txt2img_generate(api_key, service, model, inputs, parameters, **kwargs):
                return Image.open(io.BytesIO(bytes))
            else:
                url = response.json()["images"][0]["url"]
-                image = httpx.get(url, headers=headers, timeout=
+                image = httpx.get(url, headers=headers, timeout=timeout)
                return Image.open(io.BytesIO(image.content))
 
        if service == "hf":
@@ -100,7 +101,7 @@ def txt2img_generate(api_key, service, model, inputs, parameters, **kwargs):
 
        if service == "together":
            url = response.json()["data"][0]["url"]
-            image = httpx.get(url, headers=headers, timeout=
+            image = httpx.get(url, headers=headers, timeout=timeout)
            return Image.open(io.BytesIO(image.content))
 
        else:
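For orientation, a minimal sketch of how a caller drives these helpers from the new config (the service id, model id, and field names come from the config shown below; the prompt is made up):

    from lib import config, txt2img_generate

    service = "hf"
    model = "black-forest-labs/flux.1-schnell"
    api_key = config.services[service].api_key  # falls back to the HF_TOKEN env var
    image = txt2img_generate(api_key, service, model, "a watercolor fox", {"width": 1024, "height": 1024})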
lib/config.py CHANGED

@@ -1,123 +1,415 @@
 import os
-from dataclasses import dataclass
-from typing import Dict, List, Optional
+from dataclasses import dataclass, field
+from typing import Dict, List, Optional, Union
+
+TEXT_SYSTEM_PROMPT = "You are a helpful assistant. Be precise and concise."
+
+IMAGE_NEGATIVE_PROMPT = "ugly, unattractive, disfigured, deformed, mutated, malformed, blurry, grainy, oversaturated, undersaturated, overexposed, underexposed, worst quality, low details, lowres, watermark, signature, sloppy, cluttered"
+
+IMAGE_IMAGE_SIZES = [
+    "landscape_16_9",
+    "landscape_4_3",
+    "square_hd",
+    "portrait_4_3",
+    "portrait_16_9",
+]
+
+IMAGE_ASPECT_RATIOS = [
+    "704x1408",  # 1:2
+    "704x1344",  # 11:21
+    "768x1344",  # 4:7
+    "768x1280",  # 3:5
+    "832x1216",  # 13:19
+    "832x1152",  # 13:18
+    "896x1152",  # 7:9
+    "896x1088",  # 14:17
+    "960x1088",  # 15:17
+    "960x1024",  # 15:16
+    "1024x1024",
+    "1024x960",  # 16:15
+    "1088x960",  # 17:15
+    "1088x896",  # 17:14
+    "1152x896",  # 9:7
+    "1152x832",  # 18:13
+    "1216x832",  # 19:13
+    "1280x768",  # 5:3
+    "1344x768",  # 7:4
+    "1344x704",  # 21:11
+    "1408x704",  # 2:1
+]
+
+IMAGE_RANGE = (256, 1408)
 
 
 @dataclass
-class
+class ModelConfig:
     name: str
-
-
+    parameters: List[str]
+    kwargs: Optional[Dict[str, Union[str, int, float, bool]]] = field(default_factory=dict)
 
 
 @dataclass
-class
-
+class TextModelConfig(ModelConfig):
+    system_prompt: Optional[str] = None
+    frequency_penalty: Optional[float] = None
+    frequency_penalty_range: Optional[tuple[float, float]] = None
+    max_tokens: Optional[int] = None
+    max_tokens_range: Optional[tuple[int, int]] = None
+    temperature: Optional[float] = None
+    temperature_range: Optional[tuple[float, float]] = None
 
 
 @dataclass
-class
-
-
-
-
-
-
-
+class ImageModelConfig(ModelConfig):
+    negative_prompt: Optional[str] = None
+    width: Optional[int] = None
+    width_range: Optional[tuple[int, int]] = None
+    height: Optional[int] = None
+    height_range: Optional[tuple[int, int]] = None
+    image_size: Optional[str] = None
+    image_sizes: Optional[List[str]] = field(default_factory=list)
+    aspect_ratio: Optional[str] = None
+    aspect_ratios: Optional[List[str]] = field(default_factory=list)
+    guidance_scale: Optional[float] = None
+    guidance_scale_range: Optional[tuple[float, float]] = None
+    num_inference_steps: Optional[int] = None
+    num_inference_steps_range: Optional[tuple[int, int]] = None
 
 
 @dataclass
-class
+class ServiceConfig:
+    name: str
+    url: str
+    api_key: Optional[str]
+    text: Optional[Dict[str, TextModelConfig]] = field(default_factory=dict)
+    image: Optional[Dict[str, ImageModelConfig]] = field(default_factory=dict)
+
+
+@dataclass
+class AppConfig:
     title: str
     layout: str
     logo: str
-
-
-
+    timeout: int
+    hidden_parameters: List[str]
+    services: Dict[str, ServiceConfig]
+
+
+_hf_text_kwargs = {
+    "system_prompt": TEXT_SYSTEM_PROMPT,
+    "frequency_penalty": 0.0,
+    "frequency_penalty_range": (-2.0, 2.0),
+    "max_tokens": 512,
+    "max_tokens_range": (512, 4096),
+    "temperature": 1.0,
+    "temperature_range": (0.0, 2.0),
+    "parameters": ["max_tokens", "temperature", "frequency_penalty", "seed"],
+}
 
+_pplx_text_kwargs = {
+    "system_prompt": TEXT_SYSTEM_PROMPT,
+    "frequency_penalty": 1.0,
+    "frequency_penalty_range": (1.0, 2.0),
+    "max_tokens": 512,
+    "max_tokens_range": (512, 4096),
+    "temperature": 1.0,
+    "temperature_range": (0.0, 2.0),
+    "parameters": ["max_tokens", "temperature", "frequency_penalty", "seed"],
+}
 
-config =
+config = AppConfig(
     title="API Inference",
     layout="wide",
     logo="logo.png",
-
+    timeout=60,
+    hidden_parameters=[
+        # Sent to API but not shown in generation parameters accordion
+        "enable_safety_checker",
+        "max_sequence_length",
+        "n",
+        "num_images",
+        "output_format",
+        "performance",
+        "safety_tolerance",
+        "scheduler",
+        "sharpness",
+        "style",
+        "styles",
+        "sync_mode",
+    ],
+    services={
         "bfl": ServiceConfig(
-            "Black Forest Labs",
-            "https://api.bfl.ml/v1",
-            os.environ.get("BFL_API_KEY"),
+            name="Black Forest Labs",
+            url="https://api.bfl.ml/v1",
+            api_key=os.environ.get("BFL_API_KEY"),
+            image={
+                "flux-pro-1.1": ImageModelConfig(
+                    "FLUX1.1 Pro",
+                    width=1024,
+                    width_range=IMAGE_RANGE,
+                    height=1024,
+                    height_range=IMAGE_RANGE,
+                    parameters=["seed", "width", "height", "prompt_upsampling"],
+                    kwargs={"safety_tolerance": 6},
+                ),
+                "flux-pro": ImageModelConfig(
+                    "FLUX.1 Pro",
+                    width=1024,
+                    width_range=IMAGE_RANGE,
+                    height=1024,
+                    height_range=IMAGE_RANGE,
+                    guidance_scale=2.5,
+                    guidance_scale_range=(1.5, 5.0),
+                    num_inference_steps=50,
+                    num_inference_steps_range=(10, 50),
+                    parameters=["seed", "width", "height", "steps", "guidance", "prompt_upsampling"],
+                    kwargs={"safety_tolerance": 6, "interval": 1},
+                ),
+                "flux-dev": ImageModelConfig(
+                    "FLUX.1 Dev",
+                    width=1024,
+                    width_range=IMAGE_RANGE,
+                    height=1024,
+                    height_range=IMAGE_RANGE,
+                    num_inference_steps=28,
+                    num_inference_steps_range=(10, 50),
+                    guidance_scale=3.0,
+                    guidance_scale_range=(1.5, 5.0),
+                    parameters=["seed", "width", "height", "steps", "guidance", "prompt_upsampling"],
+                    kwargs={"safety_tolerance": 6},
+                ),
+            },
         ),
         "fal": ServiceConfig(
-            "Fal",
-            "https://fal.run",
-            os.environ.get("FAL_KEY"),
+            name="Fal",
+            url="https://fal.run",
+            api_key=os.environ.get("FAL_KEY"),
+            image={
+                "fal-ai/aura-flow": ImageModelConfig(
+                    "AuraFlow",
+                    guidance_scale=3.5,
+                    guidance_scale_range=(1.0, 10.0),
+                    num_inference_steps=28,
+                    num_inference_steps_range=(10, 50),
+                    parameters=["seed", "num_inference_steps", "guidance_scale", "expand_prompt"],
+                    kwargs={"num_images": 1, "sync_mode": False},
+                ),
+                "fal-ai/flux-pro/v1.1": ImageModelConfig(
+                    "FLUX1.1 Pro",
+                    parameters=["seed", "image_size"],
+                    image_size="square_hd",
+                    image_sizes=IMAGE_IMAGE_SIZES,
+                    kwargs={
+                        "num_images": 1,
+                        "sync_mode": False,
+                        "safety_tolerance": 6,
+                        "enable_safety_checker": False,
+                    },
+                ),
+                "fal-ai/flux-pro": ImageModelConfig(
+                    "FLUX.1 Pro",
+                    image_size="square_hd",
+                    image_sizes=IMAGE_IMAGE_SIZES,
+                    guidance_scale=2.5,
+                    guidance_scale_range=(1.5, 5.0),
+                    num_inference_steps=40,
+                    num_inference_steps_range=(10, 50),
+                    parameters=["seed", "image_size", "num_inference_steps", "guidance_scale"],
+                    kwargs={"num_images": 1, "sync_mode": False, "safety_tolerance": 6},
+                ),
+                "fal-ai/flux/dev": ImageModelConfig(
+                    "FLUX.1 Dev",
+                    image_size="square_hd",
+                    image_sizes=IMAGE_IMAGE_SIZES,
+                    num_inference_steps=28,
+                    num_inference_steps_range=(10, 50),
+                    guidance_scale=3.0,
+                    guidance_scale_range=(1.5, 5.0),
+                    parameters=["seed", "image_size", "num_inference_steps", "guidance_scale"],
+                    kwargs={"num_images": 1, "sync_mode": False, "safety_tolerance": 6},
+                ),
+                "fal-ai/flux/schnell": ImageModelConfig(
+                    "FLUX.1 Schnell",
+                    image_size="square_hd",
+                    image_sizes=IMAGE_IMAGE_SIZES,
+                    num_inference_steps=4,
+                    num_inference_steps_range=(1, 12),
+                    parameters=["seed", "image_size", "num_inference_steps"],
+                    kwargs={"num_images": 1, "sync_mode": False, "enable_safety_checker": False},
+                ),
+                "fal-ai/fooocus": ImageModelConfig(
+                    "Fooocus",
+                    aspect_ratio="1024x1024",
+                    aspect_ratios=IMAGE_ASPECT_RATIOS,
+                    guidance_scale=4.0,
+                    guidance_scale_range=(1.0, 15.0),
+                    parameters=["seed", "negative_prompt", "aspect_ratio", "guidance_scale"],
+                    # TODO: more of these can be params
+                    kwargs={
+                        "num_images": 1,
+                        "sync_mode": True,
+                        "enable_safety_checker": False,
+                        "output_format": "png",
+                        "sharpness": 2,
+                        "styles": ["Fooocus Enhance", "Fooocus V2", "Fooocus Sharp"],
+                        "performance": "Quality",
+                    },
+                ),
+                "fal-ai/kolors": ImageModelConfig(
+                    "Kolors",
+                    image_size="square_hd",
+                    image_sizes=IMAGE_IMAGE_SIZES,
+                    guidance_scale=5.0,
+                    guidance_scale_range=(1.0, 10.0),
+                    num_inference_steps=50,
+                    num_inference_steps_range=(10, 50),
+                    parameters=[
+                        "seed",
+                        "negative_prompt",
+                        "image_size",
+                        "guidance_scale",
+                        "num_inference_steps",
+                    ],
+                    kwargs={
+                        "num_images": 1,
+                        "sync_mode": True,
+                        "enable_safety_checker": False,
+                        "scheduler": "EulerDiscreteScheduler",
+                    },
+                ),
+                "fal-ai/stable-diffusion-v3-medium": ImageModelConfig(
+                    "SD3",
+                    image_size="square_hd",
+                    image_sizes=IMAGE_IMAGE_SIZES,
+                    guidance_scale=5.0,
+                    guidance_scale_range=(1.0, 10.0),
+                    num_inference_steps=28,
+                    num_inference_steps_range=(10, 50),
+                    parameters=[
+                        "seed",
+                        "negative_prompt",
+                        "image_size",
+                        "guidance_scale",
+                        "num_inference_steps",
+                        "prompt_expansion",
+                    ],
+                    kwargs={"num_images": 1, "sync_mode": True, "enable_safety_checker": False},
+                ),
+            },
         ),
         "hf": ServiceConfig(
-            "Hugging Face",
-            "https://api-inference.huggingface.co/models",
-            os.environ.get("HF_TOKEN"),
+            name="Hugging Face",
+            url="https://api-inference.huggingface.co/models",
+            api_key=os.environ.get("HF_TOKEN"),
+            text={
+                "codellama/codellama-34b-instruct-hf": TextModelConfig(
+                    "Code Llama 34B",
+                    **_hf_text_kwargs,
+                ),
+                "meta-llama/llama-2-13b-chat-hf": TextModelConfig(
+                    "Meta Llama 2 13B",
+                    **_hf_text_kwargs,
+                ),
+                "mistralai/mistral-7b-instruct-v0.2": TextModelConfig(
+                    "Mistral 0.2 7B",
+                    **_hf_text_kwargs,
+                ),
+                "nousresearch/nous-hermes-2-mixtral-8x7b-dpo": TextModelConfig(
+                    "Nous Hermes 2 Mixtral 8x7B",
+                    **_hf_text_kwargs,
+                ),
+            },
+            image={
+                "black-forest-labs/flux.1-dev": ImageModelConfig(
+                    "FLUX.1 Dev",
+                    width=1024,
+                    width_range=IMAGE_RANGE,
+                    height=1024,
+                    height_range=IMAGE_RANGE,
+                    guidance_scale=3.0,
+                    guidance_scale_range=(1.5, 5.0),
+                    num_inference_steps=28,
+                    num_inference_steps_range=(10, 50),
+                    parameters=["width", "height", "guidance_scale", "num_inference_steps"],
+                ),
+                "black-forest-labs/flux.1-schnell": ImageModelConfig(
+                    "FLUX.1 Schnell",
+                    width=1024,
+                    width_range=IMAGE_RANGE,
+                    height=1024,
+                    height_range=IMAGE_RANGE,
+                    num_inference_steps=4,
+                    num_inference_steps_range=(1, 12),
+                    parameters=["width", "height", "num_inference_steps"],
+                    kwargs={"guidance_scale": 0.0, "max_sequence_length": 256},
+                ),
+                "stabilityai/stable-diffusion-xl-base-1.0": ImageModelConfig(
+                    "Stable Diffusion XL 1.0",
+                    negative_prompt=IMAGE_NEGATIVE_PROMPT,
+                    width=1024,
+                    width_range=IMAGE_RANGE,
+                    height=1024,
+                    height_range=IMAGE_RANGE,
+                    guidance_scale=7.0,
+                    guidance_scale_range=(1.0, 15.0),
+                    num_inference_steps=40,
+                    num_inference_steps_range=(10, 50),
+                    parameters=[
+                        "seed",
+                        "negative_prompt",
+                        "width",
+                        "height",
+                        "guidance_scale",
+                        "num_inference_steps",
+                    ],
+                ),
+            },
         ),
         "pplx": ServiceConfig(
-            "Perplexity",
-            "https://api.perplexity.ai",
-            os.environ.get("PPLX_API_KEY"),
+            name="Perplexity",
+            url="https://api.perplexity.ai",
+            api_key=os.environ.get("PPLX_API_KEY"),
+            text={
+                "llama-3.1-sonar-small-128k-chat": TextModelConfig(
+                    "Sonar Small (Offline)",
+                    **_pplx_text_kwargs,
+                ),
+                "llama-3.1-sonar-large-128k-chat": TextModelConfig(
+                    "Sonar Large (Offline)",
+                    **_pplx_text_kwargs,
+                ),
+                "llama-3.1-sonar-small-128k-online": TextModelConfig(
+                    "Sonar Small (Online)",
+                    **_pplx_text_kwargs,
+                ),
+                "llama-3.1-sonar-large-128k-online": TextModelConfig(
+                    "Sonar Large (Online)",
+                    **_pplx_text_kwargs,
+                ),
+                "llama-3.1-sonar-huge-128k-online": TextModelConfig(
+                    "Sonar Huge (Online)",
+                    **_pplx_text_kwargs,
+                ),
+            },
         ),
+        # TODO: text models
         "together": ServiceConfig(
-            "Together",
-            "https://api.together.xyz/v1/images/generations",
-            os.environ.get("TOGETHER_API_KEY"),
+            name="Together",
+            url="https://api.together.xyz/v1/images/generations",
+            api_key=os.environ.get("TOGETHER_API_KEY"),
+            image={
+                "black-forest-labs/FLUX.1-schnell-Free": ImageModelConfig(
+                    "FLUX.1 Schnell Free",
+                    width=1024,
+                    width_range=IMAGE_RANGE,
+                    height=1024,
+                    height_range=IMAGE_RANGE,
+                    num_inference_steps=4,
+                    num_inference_steps_range=(1, 12),
+                    parameters=["model", "seed", "width", "height", "steps"],
+                    kwargs={"n": 1},
+                ),
+            },
         ),
     },
-    txt2img=Txt2ImgConfig(
-        hidden_parameters=[
-            # Sent to API but not shown in generation parameters accordion
-            "enable_safety_checker",
-            "max_sequence_length",
-            "n",
-            "num_images",
-            "output_format",
-            "performance",
-            "safety_tolerance",
-            "scheduler",
-            "sharpness",
-            "style",
-            "styles",
-            "sync_mode",
-        ],
-        negative_prompt="ugly, unattractive, disfigured, deformed, mutated, malformed, blurry, grainy, oversaturated, undersaturated, overexposed, underexposed, worst quality, low details, lowres, watermark, signature, sloppy, cluttered",
-        default_image_size="square_hd",
-        image_sizes=[
-            "landscape_16_9",
-            "landscape_4_3",
-            "square_hd",
-            "portrait_4_3",
-            "portrait_16_9",
-        ],
-        default_aspect_ratio="1024x1024",  # fooocus aspect ratios
-        aspect_ratios=[
-            "704x1408",  # 1:2
-            "704x1344",  # 11:21
-            "768x1344",  # 4:7
-            "768x1280",  # 3:5
-            "832x1216",  # 13:19
-            "832x1152",  # 13:18
-            "896x1152",  # 7:9
-            "896x1088",  # 14:17
-            "960x1088",  # 15:17
-            "960x1024",  # 15:16
-            "1024x1024",
-            "1024x960",  # 16:15
-            "1088x960",  # 17:15
-            "1088x896",  # 17:14
-            "1152x896",  # 9:7
-            "1152x832",  # 18:13
-            "1216x832",  # 19:13
-            "1280x768",  # 5:3
-            "1344x768",  # 7:4
-            "1344x704",  # 21:11
-            "1408x704",  # 2:1
-        ],
-    ),
-    txt2txt=Txt2TxtConfig(
-        default_system="You are a helpful assistant. Be precise and concise.",
-    ),
 )
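For orientation, a minimal sketch of how this structure is consumed (the lookups mirror the page diffs below; the printed values come from the config above):

    from lib.config import config

    fal = config.services["fal"]
    model = fal.image["fal-ai/flux/dev"]
    print(model.name)                  # FLUX.1 Dev
    print(model.guidance_scale_range)  # (1.5, 5.0)
    print(model.parameters)            # ['seed', 'image_size', 'num_inference_steps', 'guidance_scale']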
lib/preset.py DELETED

@@ -1,250 +0,0 @@
-from dataclasses import dataclass, field
-from typing import Dict, List, Optional, Union
-
-
-@dataclass
-class Txt2TxtPreset:
-    name: str
-    frequency_penalty: float
-    frequency_penalty_min: float
-    frequency_penalty_max: float
-    parameters: Optional[List[str]] = field(default_factory=list)
-
-
-@dataclass
-class Txt2ImgPreset:
-    name: str
-    guidance_scale: Optional[float] = None
-    guidance_scale_min: Optional[float] = None
-    guidance_scale_max: Optional[float] = None
-    num_inference_steps: Optional[int] = None
-    num_inference_steps_min: Optional[int] = None
-    num_inference_steps_max: Optional[int] = None
-    parameters: Optional[List[str]] = field(default_factory=list)
-    kwargs: Optional[Dict[str, Union[str, int, float, bool]]] = field(default_factory=dict)
-
-
-@dataclass
-class Preset:
-    txt2txt: Dict[str, Txt2TxtPreset]
-    txt2img: Dict[str, Txt2ImgPreset]
-
-
-hf_txt2txt_kwargs = {
-    "frequency_penalty": 0.0,
-    "frequency_penalty_min": -2.0,
-    "frequency_penalty_max": 2.0,
-    "parameters": ["max_tokens", "temperature", "frequency_penalty", "seed"],
-}
-
-pplx_txt2txt_kwargs = {
-    "frequency_penalty": 1.0,
-    "frequency_penalty_min": 1.0,
-    "frequency_penalty_max": 2.0,
-    "parameters": ["max_tokens", "temperature", "frequency_penalty"],
-}
-
-
-preset = Preset(
-    txt2txt={
-        "hf": {
-            # TODO: update models
-            "codellama/codellama-34b-instruct-hf": Txt2TxtPreset("Code Llama 34B", **hf_txt2txt_kwargs),
-            "meta-llama/llama-2-13b-chat-hf": Txt2TxtPreset("Llama 2 13B", **hf_txt2txt_kwargs),
-            "mistralai/mistral-7b-instruct-v0.2": Txt2TxtPreset("Mistral v0.2 7B", **hf_txt2txt_kwargs),
-            "nousresearch/nous-hermes-2-mixtral-8x7b-dpo": Txt2TxtPreset(
-                "Nous Hermes 2 Mixtral 8x7B",
-                **hf_txt2txt_kwargs,
-            ),
-        },
-        "pplx": {
-            "llama-3.1-sonar-small-128k-chat": Txt2TxtPreset("Sonar Small (Offline)", **pplx_txt2txt_kwargs),
-            "llama-3.1-sonar-large-128k-chat": Txt2TxtPreset("Sonar Large (Offline)", **pplx_txt2txt_kwargs),
-            "llama-3.1-sonar-small-128k-online": Txt2TxtPreset("Sonar Small (Online)", **pplx_txt2txt_kwargs),
-            "llama-3.1-sonar-large-128k-online": Txt2TxtPreset("Sonar Large (Online)", **pplx_txt2txt_kwargs),
-            "llama-3.1-sonar-huge-128k-online": Txt2TxtPreset("Sonar Huge (Online)", **pplx_txt2txt_kwargs),
-        },
-    },
-    txt2img={
-        "bfl": {
-            "flux-pro-1.1": Txt2ImgPreset(
-                "FLUX1.1 Pro",
-                parameters=["seed", "width", "height", "prompt_upsampling"],
-                kwargs={"safety_tolerance": 6},
-            ),
-            "flux-pro": Txt2ImgPreset(
-                "FLUX.1 Pro",
-                guidance_scale=2.5,
-                guidance_scale_min=1.5,
-                guidance_scale_max=5.0,
-                num_inference_steps=40,
-                num_inference_steps_min=10,
-                num_inference_steps_max=50,
-                parameters=["seed", "width", "height", "steps", "guidance", "prompt_upsampling"],
-                kwargs={"safety_tolerance": 6, "interval": 1},
-            ),
-            "flux-dev": Txt2ImgPreset(
-                "FLUX.1 Dev",
-                num_inference_steps=28,
-                num_inference_steps_min=10,
-                num_inference_steps_max=50,
-                guidance_scale=3.0,
-                guidance_scale_min=1.5,
-                guidance_scale_max=5.0,
-                parameters=["seed", "width", "height", "steps", "guidance", "prompt_upsampling"],
-                kwargs={"safety_tolerance": 6},
-            ),
-        },
-        "fal": {
-            "fal-ai/aura-flow": Txt2ImgPreset(
-                "AuraFlow",
-                guidance_scale=3.5,
-                guidance_scale_min=1.0,
-                guidance_scale_max=10.0,
-                num_inference_steps=28,
-                num_inference_steps_min=10,
-                num_inference_steps_max=50,
-                parameters=["seed", "num_inference_steps", "guidance_scale", "expand_prompt"],
-                kwargs={"num_images": 1, "sync_mode": False},
-            ),
-            "fal-ai/flux-pro/v1.1": Txt2ImgPreset(
-                "FLUX1.1 Pro",
-                parameters=["seed", "image_size"],
-                kwargs={
-                    "num_images": 1,
-                    "sync_mode": False,
-                    "safety_tolerance": 6,
-                    "enable_safety_checker": False,
-                },
-            ),
-            "fal-ai/flux-pro": Txt2ImgPreset(
-                "FLUX.1 Pro",
-                guidance_scale=2.5,
-                guidance_scale_min=1.5,
-                guidance_scale_max=5.0,
-                num_inference_steps=40,
-                num_inference_steps_min=10,
-                num_inference_steps_max=50,
-                parameters=["seed", "image_size", "num_inference_steps", "guidance_scale"],
-                kwargs={"num_images": 1, "sync_mode": False, "safety_tolerance": 6},
-            ),
-            "fal-ai/flux/dev": Txt2ImgPreset(
-                "FLUX.1 Dev",
-                num_inference_steps=28,
-                num_inference_steps_min=10,
-                num_inference_steps_max=50,
-                guidance_scale=3.0,
-                guidance_scale_min=1.5,
-                guidance_scale_max=5.0,
-                parameters=["seed", "image_size", "num_inference_steps", "guidance_scale"],
-                kwargs={"num_images": 1, "sync_mode": False, "safety_tolerance": 6},
-            ),
-            "fal-ai/flux/schnell": Txt2ImgPreset(
-                "FLUX.1 Schnell",
-                num_inference_steps=4,
-                num_inference_steps_min=1,
-                num_inference_steps_max=12,
-                parameters=["seed", "image_size", "num_inference_steps"],
-                kwargs={"num_images": 1, "sync_mode": False, "enable_safety_checker": False},
-            ),
-            "fal-ai/fooocus": Txt2ImgPreset(
-                "Fooocus",
-                guidance_scale=4.0,
-                guidance_scale_min=1.0,
-                guidance_scale_max=10.0,
-                parameters=["seed", "negative_prompt", "aspect_ratio", "guidance_scale"],
-                kwargs={
-                    "num_images": 1,
-                    "sync_mode": True,
-                    "enable_safety_checker": False,
-                    "output_format": "png",
-                    "sharpness": 2,
-                    "styles": ["Fooocus Enhance", "Fooocus V2", "Fooocus Sharp"],
-                    "performance": "Quality",
-                },
-            ),
-            "fal-ai/kolors": Txt2ImgPreset(
-                "Kolors",
-                guidance_scale=5.0,
-                guidance_scale_min=1.0,
-                guidance_scale_max=10.0,
-                num_inference_steps=50,
-                num_inference_steps_min=10,
-                num_inference_steps_max=50,
-                parameters=["seed", "negative_prompt", "image_size", "guidance_scale", "num_inference_steps"],
-                kwargs={
-                    "num_images": 1,
-                    "sync_mode": True,
-                    "enable_safety_checker": False,
-                    "scheduler": "EulerDiscreteScheduler",
-                },
-            ),
-            "fal-ai/stable-diffusion-v3-medium": Txt2ImgPreset(
-                "SD3",
-                guidance_scale=5.0,
-                guidance_scale_min=1.0,
-                guidance_scale_max=10.0,
-                num_inference_steps=28,
-                num_inference_steps_min=10,
-                num_inference_steps_max=50,
-                parameters=[
-                    "seed",
-                    "negative_prompt",
-                    "image_size",
-                    "guidance_scale",
-                    "num_inference_steps",
-                    "prompt_expansion",
-                ],
-                kwargs={"num_images": 1, "sync_mode": True, "enable_safety_checker": False},
-            ),
-        },
-        "hf": {
-            "black-forest-labs/flux.1-dev": Txt2ImgPreset(
-                "FLUX.1 Dev",
-                num_inference_steps=28,
-                num_inference_steps_min=10,
-                num_inference_steps_max=50,
-                guidance_scale=3.0,
-                guidance_scale_min=1.5,
-                guidance_scale_max=5.0,
-                parameters=["width", "height", "guidance_scale", "num_inference_steps"],
-                kwargs={"max_sequence_length": 512},
-            ),
-            "black-forest-labs/flux.1-schnell": Txt2ImgPreset(
-                "FLUX.1 Schnell",
-                num_inference_steps=4,
-                num_inference_steps_min=1,
-                num_inference_steps_max=12,
-                parameters=["width", "height", "num_inference_steps"],
-                kwargs={"guidance_scale": 0.0, "max_sequence_length": 256},
-            ),
-            "stabilityai/stable-diffusion-xl-base-1.0": Txt2ImgPreset(
-                "SDXL",
-                guidance_scale=7.0,
-                guidance_scale_min=1.0,
-                guidance_scale_max=10.0,
-                num_inference_steps=40,
-                num_inference_steps_min=10,
-                num_inference_steps_max=50,
-                parameters=[
-                    "seed",
-                    "negative_prompt",
-                    "width",
-                    "height",
-                    "guidance_scale",
-                    "num_inference_steps",
-                ],
-            ),
-        },
-        "together": {
-            "black-forest-labs/FLUX.1-schnell-Free": Txt2ImgPreset(
-                "FLUX.1 Schnell Free",
-                num_inference_steps=4,
-                num_inference_steps_min=1,
-                num_inference_steps_max=12,
-                parameters=["model", "seed", "width", "height", "steps"],
-                kwargs={"n": 1},
-            ),
-        },
-    },
-)
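The removed preset fields map onto the new config one-for-one: separate *_min/*_max pairs become *_range tuples, and per-model defaults move under each ServiceConfig. A sketch of the mapping, taken from the two files above:

    # Before this commit:
    #   preset.txt2img["fal"]["fal-ai/flux/dev"].guidance_scale_min  -> 1.5
    #   preset.txt2img["fal"]["fal-ai/flux/dev"].guidance_scale_max  -> 5.0
    # After this commit:
    #   config.services["fal"].image["fal-ai/flux/dev"].guidance_scale_range  -> (1.5, 5.0)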
pages/1_💬_Text_Generation.py CHANGED

@@ -1,9 +1,8 @@
 from datetime import datetime
-from typing import Dict
 
 import streamlit as st
 
-from lib import
+from lib import config, txt2txt_generate
 
 # config
 st.set_page_config(
@@ -32,19 +31,24 @@ if "txt2txt_seed" not in st.session_state:
 st.logo(config.logo)
 st.sidebar.header("Settings")
 
+text_services = {
+    service_id: service_config
+    for service_id, service_config in config.services.items()
+    if getattr(service_config, "text", None)
+}
+
 service = st.sidebar.selectbox(
     "Service",
-    options=
-    format_func=lambda x:
+    options=text_services.keys(),
+    format_func=lambda x: text_services[x].name,
     disabled=st.session_state.running,
 )
 
-#
-
-for service_id, service_config in config.service.items():
+# Show the API key input for the selected service.
+for service_id, service_preset in text_services.items():
     if service == service_id:
         session_key = f"api_key_{service}"
-        api_key =
+        api_key = service_preset.api_key
         st.session_state[session_key] = st.sidebar.text_input(
             "API Key",
             type="password",
@@ -53,33 +57,33 @@ for service_id, service_config in config.service.items():
             help="Set by environment variable" if api_key else "Cleared on page refresh",
         )
 
-
+service_config = text_services[service]
 
 model = st.sidebar.selectbox(
     "Model",
-    options=
-    format_func=lambda x:
+    options=service_config.text.keys(),
+    format_func=lambda x: service_config.text[x].name,
     disabled=st.session_state.running,
 )
 
-
+model_config = service_config.text[model]
 
 system = st.sidebar.text_area(
     "System Message",
-    value=
+    value=model_config.system_prompt,
     disabled=st.session_state.running,
 )
 
 # build parameters from preset
 parameters = {}
-for param in
+for param in model_config.parameters:
     if param == "max_tokens":
         parameters[param] = st.sidebar.slider(
             "Max Tokens",
             step=512,
-            value=
-            min_value=
-            max_value=
+            value=model_config.max_tokens,
+            min_value=model_config.max_tokens_range[0],
+            max_value=model_config.max_tokens_range[1],
             disabled=st.session_state.running,
             help="Maximum number of tokens to generate (default: 512)",
         )
@@ -87,9 +91,9 @@ for param in model_preset.parameters:
         parameters[param] = st.sidebar.slider(
             "Temperature",
             step=0.1,
-            value=
-            min_value=
-            max_value=
+            value=model_config.temperature,
+            min_value=model_config.temperature_range[0],
+            max_value=model_config.temperature_range[1],
             disabled=st.session_state.running,
             help="Used to modulate the next token probabilities (default: 1.0)",
         )
@@ -97,9 +101,9 @@ for param in model_preset.parameters:
         parameters[param] = st.sidebar.slider(
             "Frequency Penalty",
             step=0.1,
-            value=
-            min_value=
-            max_value=
+            value=model_config.frequency_penalty,
+            min_value=model_config.frequency_penalty_range[0],
+            max_value=model_config.frequency_penalty_range[1],
            disabled=st.session_state.running,
            help="Penalize new tokens based on their existing frequency in the text (default: 0.0)",
        )
@@ -177,7 +181,7 @@ if prompt := st.chat_input(
 
     with st.chat_message("assistant"):
         session_key = f"api_key_{service}"
-        api_key = st.session_state[session_key] or
+        api_key = st.session_state[session_key] or text_services[service].api_key
         response = txt2txt_generate(api_key, service, model, parameters)
         st.session_state.running = False
 
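A quick way to sanity-check the service filtering introduced here (a sketch run outside Streamlit; the expected output follows from the config above, where only the Hugging Face and Perplexity entries define text models):

    from lib import config

    text_services = {k: v for k, v in config.services.items() if v.text}
    print(list(text_services))  # ['hf', 'pplx']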
pages/2_🎨_Text_to_Image.py CHANGED

@@ -1,9 +1,8 @@
 from datetime import datetime
-from typing import Dict
 
 import streamlit as st
 
-from lib import
+from lib import config, txt2img_generate
 
 st.set_page_config(
     page_title=f"{config.title} | Text to Image",
@@ -36,20 +35,24 @@ if "txt2img_seed" not in st.session_state:
 st.logo(config.logo)
 st.sidebar.header("Settings")
 
+image_services = {
+    service_id: service_config
+    for service_id, service_config in config.services.items()
+    if getattr(service_config, "image", None)
+}
+
 service = st.sidebar.selectbox(
     "Service",
-    options=
-    format_func=lambda x:
+    options=image_services.keys(),
+    format_func=lambda x: image_services[x].name,
     disabled=st.session_state.running,
 )
 
 # Show the API key input for the selected service.
-
-# for display_name, session_key in SERVICE_SESSION.items():
-for service_id in config.service.keys():
+for service_id, service_config in image_services.items():
     if service == service_id:
         session_key = f"api_key_{service}"
-        api_key =
+        api_key = service_config.api_key
         st.session_state[session_key] = st.sidebar.text_input(
             "API Key",
             type="password",
@@ -58,16 +61,16 @@ for service_id in config.service.keys():
             help="Set by environment variable" if api_key else "Cleared on page refresh",
         )
 
-
+service_config = image_services[service]
 
 model = st.sidebar.selectbox(
     "Model",
-    options=
-    format_func=lambda x:
+    options=service_config.image.keys(),
+    format_func=lambda x: service_config.image[x].name,
     disabled=st.session_state.running,
 )
 
-
+model_config = service_config.image[model]
 
 # heading
 st.html("""
@@ -77,7 +80,7 @@ st.html("""
 
 # Build parameters from preset by rendering the appropriate input widgets
 parameters = {}
-for param in
+for param in model_config.parameters:
     if param == "model":
         parameters[param] = model
     if param == "seed":
@@ -91,56 +94,56 @@ for param in model_preset.parameters:
     if param == "negative_prompt":
         parameters[param] = st.sidebar.text_area(
             "Negative Prompt",
-            value=
+            value=model_config.negative_prompt,
             disabled=st.session_state.running,
         )
     if param == "width":
         parameters[param] = st.sidebar.slider(
             "Width",
             step=64,
-            value=
-            min_value=
-            max_value=
+            value=model_config.width,
+            min_value=model_config.width_range[0],
+            max_value=model_config.width_range[1],
             disabled=st.session_state.running,
         )
     if param == "height":
         parameters[param] = st.sidebar.slider(
             "Height",
             step=64,
-            value=
-            min_value=
-            max_value=
+            value=model_config.height,
+            min_value=model_config.height_range[0],
+            max_value=model_config.height_range[1],
             disabled=st.session_state.running,
         )
     if param == "image_size":
         parameters[param] = st.sidebar.select_slider(
             "Image Size",
-            options=
-            value=
+            options=model_config.image_sizes,
+            value=model_config.image_size,
             disabled=st.session_state.running,
         )
     if param == "aspect_ratio":
         parameters[param] = st.sidebar.select_slider(
             "Aspect Ratio",
-            options=
-            value=
+            options=model_config.aspect_ratios,
+            value=model_config.aspect_ratio,
             disabled=st.session_state.running,
         )
     if param in ["guidance_scale", "guidance"]:
         parameters[param] = st.sidebar.slider(
             "Guidance Scale",
-
-
-
+            model_config.guidance_scale_range[0],
+            model_config.guidance_scale_range[1],
+            model_config.guidance_scale,
             0.1,
             disabled=st.session_state.running,
         )
     if param in ["num_inference_steps", "steps"]:
         parameters[param] = st.sidebar.slider(
             "Inference Steps",
-
-
-
+            model_config.num_inference_steps_range[0],
+            model_config.num_inference_steps_range[1],
+            model_config.num_inference_steps,
             1,
             disabled=st.session_state.running,
         )
@@ -175,7 +178,7 @@ for message in st.session_state.txt2img_messages:
     filtered_parameters = [
         f"`{k}`: {v}"
         for k, v in message["parameters"].items()
-        if k not in config.
+        if k not in config.hidden_parameters
     ]
     st.markdown(f"`model`: {message['model']}\n\n" + "\n\n".join(filtered_parameters))
 
@@ -248,15 +251,15 @@ if prompt := st.chat_input(
 
     with st.chat_message("assistant"):
         with st.spinner("Running..."):
-            if
-                parameters.update(
+            if model_config.kwargs:
+                parameters.update(model_config.kwargs)
             session_key = f"api_key_{service}"
-            api_key = st.session_state[session_key] or
+            api_key = st.session_state[session_key] or image_services[service].api_key
             image = txt2img_generate(api_key, service, model, prompt, parameters)
             st.session_state.running = False
 
     st.session_state.txt2img_messages.append(
-        {"role": "user", "content": prompt, "parameters": parameters, "model":
+        {"role": "user", "content": prompt, "parameters": parameters, "model": model_config.name}
     )
     st.session_state.txt2img_messages.append({"role": "assistant", "content": image})
     st.rerun()
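To see how the fixed kwargs merge into the request, a small sketch (values taken from the fal-ai/flux/schnell entry in the config above; the prompt-independent seed is made up):

    model_config = config.services["fal"].image["fal-ai/flux/schnell"]
    parameters = {"seed": 42, "image_size": "square_hd", "num_inference_steps": 4}
    parameters.update(model_config.kwargs)
    # parameters now also carries num_images=1, sync_mode=False, enable_safety_checker=False,
    # which the page hides from the chat log via config.hidden_parameters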