adamelliotfields committed (verified)
Commit b7490f8 · 1 Parent(s): c659a88

Move presets to config
lib/__init__.py CHANGED
@@ -1,12 +1,8 @@
 from .api import txt2img_generate, txt2txt_generate
 from .config import config
-from .preset import Txt2ImgPreset, Txt2TxtPreset, preset

 __all__ = [
-    "Txt2ImgPreset",
-    "Txt2TxtPreset",
     "config",
-    "preset",
     "txt2img_generate",
     "txt2txt_generate",
 ]
lib/api.py CHANGED
@@ -11,7 +11,7 @@ from .config import config


 def txt2txt_generate(api_key, service, model, parameters, **kwargs):
-    base_url = config.service[service].url
+    base_url = config.services[service].url
     if service == "hf":
         base_url = f"{base_url}/{model}/v1"
     client = OpenAI(api_key=api_key, base_url=base_url)
@@ -52,23 +52,24 @@ def txt2img_generate(api_key, service, model, inputs, parameters, **kwargs):
     headers["Authorization"] = f"Bearer {api_key}"
     json["prompt"] = inputs

-    base_url = config.service[service].url
+    base_url = config.services[service].url

     if service not in ["together"]:
         base_url = f"{base_url}/{model}"

     try:
-        response = httpx.post(base_url, headers=headers, json=json, timeout=config.txt2img.timeout)
+        timeout = config.timeout
+        response = httpx.post(base_url, headers=headers, json=json, timeout=timeout)
         if response.status_code // 100 == 2:  # 2xx
             # BFL is async so we need to poll for result
             # https://api.bfl.ml/docs
             if service == "bfl":
                 id = response.json()["id"]
-                url = f"{config.service[service].url}/get_result?id={id}"
+                url = f"{config.services[service].url}/get_result?id={id}"

                 retries = 0
-                while retries < config.txt2img.timeout:
-                    response = httpx.get(url, timeout=config.txt2img.timeout)
+                while retries < timeout:
+                    response = httpx.get(url, timeout=timeout)
                     if response.status_code // 100 != 2:
                         return f"Error: {response.status_code} {response.text}"

@@ -76,7 +77,7 @@ def txt2img_generate(api_key, service, model, inputs, parameters, **kwargs):
                     image = httpx.get(
                         response.json()["result"]["sample"],
                         headers=headers,
-                        timeout=config.txt2img.timeout,
+                        timeout=timeout,
                     )
                     return Image.open(io.BytesIO(image.content))

@@ -92,7 +93,7 @@ def txt2img_generate(api_key, service, model, inputs, parameters, **kwargs):
                     return Image.open(io.BytesIO(bytes))
                 else:
                     url = response.json()["images"][0]["url"]
-                    image = httpx.get(url, headers=headers, timeout=config.txt2img.timeout)
+                    image = httpx.get(url, headers=headers, timeout=timeout)
                     return Image.open(io.BytesIO(image.content))

         if service == "hf":
@@ -100,7 +101,7 @@ def txt2img_generate(api_key, service, model, inputs, parameters, **kwargs):

         if service == "together":
             url = response.json()["data"][0]["url"]
-            image = httpx.get(url, headers=headers, timeout=config.txt2img.timeout)
+            image = httpx.get(url, headers=headers, timeout=timeout)
             return Image.open(io.BytesIO(image.content))

         else:
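As the comment in the hunk above notes, BFL's API is asynchronous: the POST returns an id and the image has to be polled for via `get_result`, with `config.timeout` reused as both the HTTP timeout and the retry bound. A minimal sketch of that polling pattern for context (the sleep interval and the "Ready" status value are assumptions, not shown in this diff):

    import time
    import httpx

    def poll_bfl_result(url: str, timeout: int) -> dict:
        # Poll the get_result endpoint until the job finishes or the bound is hit.
        retries = 0
        while retries < timeout:
            response = httpx.get(url, timeout=timeout)
            response.raise_for_status()
            payload = response.json()
            if payload.get("status") == "Ready":  # assumed status value
                return payload["result"]
            retries += 1
            time.sleep(1)  # assumed 1s polling interval
        raise TimeoutError("BFL result was not ready in time")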
lib/config.py CHANGED
@@ -1,123 +1,415 @@
 import os
-from dataclasses import dataclass
-from typing import Dict, List, Optional
+from dataclasses import dataclass, field
+from typing import Dict, List, Optional, Union
+
+TEXT_SYSTEM_PROMPT = "You are a helpful assistant. Be precise and concise."
+
+IMAGE_NEGATIVE_PROMPT = "ugly, unattractive, disfigured, deformed, mutated, malformed, blurry, grainy, oversaturated, undersaturated, overexposed, underexposed, worst quality, low details, lowres, watermark, signature, sloppy, cluttered"
+
+IMAGE_IMAGE_SIZES = [
+    "landscape_16_9",
+    "landscape_4_3",
+    "square_hd",
+    "portrait_4_3",
+    "portrait_16_9",
+]
+
+IMAGE_ASPECT_RATIOS = [
+    "704x1408",  # 1:2
+    "704x1344",  # 11:21
+    "768x1344",  # 4:7
+    "768x1280",  # 3:5
+    "832x1216",  # 13:19
+    "832x1152",  # 13:18
+    "896x1152",  # 7:9
+    "896x1088",  # 14:17
+    "960x1088",  # 15:17
+    "960x1024",  # 15:16
+    "1024x1024",
+    "1024x960",  # 16:15
+    "1088x960",  # 17:15
+    "1088x896",  # 17:14
+    "1152x896",  # 9:7
+    "1152x832",  # 18:13
+    "1216x832",  # 19:13
+    "1280x768",  # 5:3
+    "1344x768",  # 7:4
+    "1344x704",  # 21:11
+    "1408x704",  # 2:1
+]
+
+IMAGE_RANGE = (256, 1408)


 @dataclass
-class ServiceConfig:
+class ModelConfig:
     name: str
-    url: str
-    api_key: Optional[str] = None
+    parameters: List[str]
+    kwargs: Optional[Dict[str, Union[str, int, float, bool]]] = field(default_factory=dict)


 @dataclass
-class Txt2TxtConfig:
-    default_system: str
+class TextModelConfig(ModelConfig):
+    system_prompt: Optional[str] = None
+    frequency_penalty: Optional[float] = None
+    frequency_penalty_range: Optional[tuple[float, float]] = None
+    max_tokens: Optional[int] = None
+    max_tokens_range: Optional[tuple[int, int]] = None
+    temperature: Optional[float] = None
+    temperature_range: Optional[tuple[float, float]] = None


 @dataclass
-class Txt2ImgConfig:
-    hidden_parameters: List[str]
-    negative_prompt: str
-    default_image_size: str
-    image_sizes: List[str]
-    default_aspect_ratio: str
-    aspect_ratios: List[str]
-    timeout: int = 60
+class ImageModelConfig(ModelConfig):
+    negative_prompt: Optional[str] = None
+    width: Optional[int] = None
+    width_range: Optional[tuple[int, int]] = None
+    height: Optional[int] = None
+    height_range: Optional[tuple[int, int]] = None
+    image_size: Optional[str] = None
+    image_sizes: Optional[List[str]] = field(default_factory=list)
+    aspect_ratio: Optional[str] = None
+    aspect_ratios: Optional[List[str]] = field(default_factory=list)
+    guidance_scale: Optional[float] = None
+    guidance_scale_range: Optional[tuple[float, float]] = None
+    num_inference_steps: Optional[int] = None
+    num_inference_steps_range: Optional[tuple[int, int]] = None


 @dataclass
-class Config:
+class ServiceConfig:
+    name: str
+    url: str
+    api_key: Optional[str]
+    text: Optional[Dict[str, TextModelConfig]] = field(default_factory=dict)
+    image: Optional[Dict[str, ImageModelConfig]] = field(default_factory=dict)
+
+
+@dataclass
+class AppConfig:
     title: str
     layout: str
     logo: str
-    service: Dict[str, ServiceConfig]
-    txt2img: Txt2ImgConfig
-    txt2txt: Txt2TxtConfig
+    timeout: int
+    hidden_parameters: List[str]
+    services: Dict[str, ServiceConfig]
+
+
+_hf_text_kwargs = {
+    "system_prompt": TEXT_SYSTEM_PROMPT,
+    "frequency_penalty": 0.0,
+    "frequency_penalty_range": (-2.0, 2.0),
+    "max_tokens": 512,
+    "max_tokens_range": (512, 4096),
+    "temperature": 1.0,
+    "temperature_range": (0.0, 2.0),
+    "parameters": ["max_tokens", "temperature", "frequency_penalty", "seed"],
+}

+_pplx_text_kwargs = {
+    "system_prompt": TEXT_SYSTEM_PROMPT,
+    "frequency_penalty": 1.0,
+    "frequency_penalty_range": (1.0, 2.0),
+    "max_tokens": 512,
+    "max_tokens_range": (512, 4096),
+    "temperature": 1.0,
+    "temperature_range": (0.0, 2.0),
+    "parameters": ["max_tokens", "temperature", "frequency_penalty", "seed"],
+}

-config = Config(
+config = AppConfig(
     title="API Inference",
     layout="wide",
     logo="logo.png",
-    service={
+    timeout=60,
+    hidden_parameters=[
+        # Sent to API but not shown in generation parameters accordion
+        "enable_safety_checker",
+        "max_sequence_length",
+        "n",
+        "num_images",
+        "output_format",
+        "performance",
+        "safety_tolerance",
+        "scheduler",
+        "sharpness",
+        "style",
+        "styles",
+        "sync_mode",
+    ],
+    services={
         "bfl": ServiceConfig(
-            "Black Forest Labs",
-            "https://api.bfl.ml/v1",
-            os.environ.get("BFL_API_KEY"),
+            name="Black Forest Labs",
+            url="https://api.bfl.ml/v1",
+            api_key=os.environ.get("BFL_API_KEY"),
+            image={
+                "flux-pro-1.1": ImageModelConfig(
+                    "FLUX1.1 Pro",
+                    width=1024,
+                    width_range=IMAGE_RANGE,
+                    height=1024,
+                    height_range=IMAGE_RANGE,
+                    parameters=["seed", "width", "height", "prompt_upsampling"],
+                    kwargs={"safety_tolerance": 6},
+                ),
+                "flux-pro": ImageModelConfig(
+                    "FLUX.1 Pro",
+                    width=1024,
+                    width_range=IMAGE_RANGE,
+                    height=1024,
+                    height_range=IMAGE_RANGE,
+                    guidance_scale=2.5,
+                    guidance_scale_range=(1.5, 5.0),
+                    num_inference_steps=50,
+                    num_inference_steps_range=(10, 50),
+                    parameters=["seed", "width", "height", "steps", "guidance", "prompt_upsampling"],
+                    kwargs={"safety_tolerance": 6, "interval": 1},
+                ),
+                "flux-dev": ImageModelConfig(
+                    "FLUX.1 Dev",
+                    width=1024,
+                    width_range=IMAGE_RANGE,
+                    height=1024,
+                    height_range=IMAGE_RANGE,
+                    num_inference_steps=28,
+                    num_inference_steps_range=(10, 50),
+                    guidance_scale=3.0,
+                    guidance_scale_range=(1.5, 5.0),
+                    parameters=["seed", "width", "height", "steps", "guidance", "prompt_upsampling"],
+                    kwargs={"safety_tolerance": 6},
+                ),
+            },
         ),
         "fal": ServiceConfig(
-            "Fal",
-            "https://fal.run",
-            os.environ.get("FAL_KEY"),
+            name="Fal",
+            url="https://fal.run",
+            api_key=os.environ.get("FAL_KEY"),
+            image={
+                "fal-ai/aura-flow": ImageModelConfig(
+                    "AuraFlow",
+                    guidance_scale=3.5,
+                    guidance_scale_range=(1.0, 10.0),
+                    num_inference_steps=28,
+                    num_inference_steps_range=(10, 50),
+                    parameters=["seed", "num_inference_steps", "guidance_scale", "expand_prompt"],
+                    kwargs={"num_images": 1, "sync_mode": False},
+                ),
+                "fal-ai/flux-pro/v1.1": ImageModelConfig(
+                    "FLUX1.1 Pro",
+                    parameters=["seed", "image_size"],
+                    image_size="square_hd",
+                    image_sizes=IMAGE_IMAGE_SIZES,
+                    kwargs={
+                        "num_images": 1,
+                        "sync_mode": False,
+                        "safety_tolerance": 6,
+                        "enable_safety_checker": False,
+                    },
+                ),
+                "fal-ai/flux-pro": ImageModelConfig(
+                    "FLUX.1 Pro",
+                    image_size="square_hd",
+                    image_sizes=IMAGE_IMAGE_SIZES,
+                    guidance_scale=2.5,
+                    guidance_scale_range=(1.5, 5.0),
+                    num_inference_steps=40,
+                    num_inference_steps_range=(10, 50),
+                    parameters=["seed", "image_size", "num_inference_steps", "guidance_scale"],
+                    kwargs={"num_images": 1, "sync_mode": False, "safety_tolerance": 6},
+                ),
+                "fal-ai/flux/dev": ImageModelConfig(
+                    "FLUX.1 Dev",
+                    image_size="square_hd",
+                    image_sizes=IMAGE_IMAGE_SIZES,
+                    num_inference_steps=28,
+                    num_inference_steps_range=(10, 50),
+                    guidance_scale=3.0,
+                    guidance_scale_range=(1.5, 5.0),
+                    parameters=["seed", "image_size", "num_inference_steps", "guidance_scale"],
+                    kwargs={"num_images": 1, "sync_mode": False, "safety_tolerance": 6},
+                ),
+                "fal-ai/flux/schnell": ImageModelConfig(
+                    "FLUX.1 Schnell",
+                    image_size="square_hd",
+                    image_sizes=IMAGE_IMAGE_SIZES,
+                    num_inference_steps=4,
+                    num_inference_steps_range=(1, 12),
+                    parameters=["seed", "image_size", "num_inference_steps"],
+                    kwargs={"num_images": 1, "sync_mode": False, "enable_safety_checker": False},
+                ),
+                "fal-ai/fooocus": ImageModelConfig(
+                    "Fooocus",
+                    aspect_ratio="1024x1024",
+                    aspect_ratios=IMAGE_ASPECT_RATIOS,
+                    guidance_scale=4.0,
+                    guidance_scale_range=(1.0, 15.0),
+                    parameters=["seed", "negative_prompt", "aspect_ratio", "guidance_scale"],
+                    # TODO: more of these can be params
+                    kwargs={
+                        "num_images": 1,
+                        "sync_mode": True,
+                        "enable_safety_checker": False,
+                        "output_format": "png",
+                        "sharpness": 2,
+                        "styles": ["Fooocus Enhance", "Fooocus V2", "Fooocus Sharp"],
+                        "performance": "Quality",
+                    },
+                ),
+                "fal-ai/kolors": ImageModelConfig(
+                    "Kolors",
+                    image_size="square_hd",
+                    image_sizes=IMAGE_IMAGE_SIZES,
+                    guidance_scale=5.0,
+                    guidance_scale_range=(1.0, 10.0),
+                    num_inference_steps=50,
+                    num_inference_steps_range=(10, 50),
+                    parameters=[
+                        "seed",
+                        "negative_prompt",
+                        "image_size",
+                        "guidance_scale",
+                        "num_inference_steps",
+                    ],
+                    kwargs={
+                        "num_images": 1,
+                        "sync_mode": True,
+                        "enable_safety_checker": False,
+                        "scheduler": "EulerDiscreteScheduler",
+                    },
+                ),
+                "fal-ai/stable-diffusion-v3-medium": ImageModelConfig(
+                    "SD3",
+                    image_size="square_hd",
+                    image_sizes=IMAGE_IMAGE_SIZES,
+                    guidance_scale=5.0,
+                    guidance_scale_range=(1.0, 10.0),
+                    num_inference_steps=28,
+                    num_inference_steps_range=(10, 50),
+                    parameters=[
+                        "seed",
+                        "negative_prompt",
+                        "image_size",
+                        "guidance_scale",
+                        "num_inference_steps",
+                        "prompt_expansion",
+                    ],
+                    kwargs={"num_images": 1, "sync_mode": True, "enable_safety_checker": False},
+                ),
+            },
         ),
         "hf": ServiceConfig(
-            "Hugging Face",
-            "https://api-inference.huggingface.co/models",
-            os.environ.get("HF_TOKEN"),
+            name="Hugging Face",
+            url="https://api-inference.huggingface.co/models",
+            api_key=os.environ.get("HF_TOKEN"),
+            text={
+                "codellama/codellama-34b-instruct-hf": TextModelConfig(
+                    "Code Llama 34B",
+                    **_hf_text_kwargs,
+                ),
+                "meta-llama/llama-2-13b-chat-hf": TextModelConfig(
+                    "Meta Llama 2 13B",
+                    **_hf_text_kwargs,
+                ),
+                "mistralai/mistral-7b-instruct-v0.2": TextModelConfig(
+                    "Mistral 0.2 7B",
+                    **_hf_text_kwargs,
+                ),
+                "nousresearch/nous-hermes-2-mixtral-8x7b-dpo": TextModelConfig(
+                    "Nous Hermes 2 Mixtral 8x7B",
+                    **_hf_text_kwargs,
+                ),
+            },
+            image={
+                "black-forest-labs/flux.1-dev": ImageModelConfig(
+                    "FLUX.1 Dev",
+                    width=1024,
+                    width_range=IMAGE_RANGE,
+                    height=1024,
+                    height_range=IMAGE_RANGE,
+                    guidance_scale=3.0,
+                    guidance_scale_range=(1.5, 5.0),
+                    num_inference_steps=28,
+                    num_inference_steps_range=(10, 50),
+                    parameters=["width", "height", "guidance_scale", "num_inference_steps"],
+                ),
+                "black-forest-labs/flux.1-schnell": ImageModelConfig(
+                    "FLUX.1 Schnell",
+                    width=1024,
+                    width_range=IMAGE_RANGE,
+                    height=1024,
+                    height_range=IMAGE_RANGE,
+                    num_inference_steps=4,
+                    num_inference_steps_range=(1, 12),
+                    parameters=["width", "height", "num_inference_steps"],
+                    kwargs={"guidance_scale": 0.0, "max_sequence_length": 256},
+                ),
+                "stabilityai/stable-diffusion-xl-base-1.0": ImageModelConfig(
+                    "Stable Diffusion XL 1.0",
+                    negative_prompt=IMAGE_NEGATIVE_PROMPT,
+                    width=1024,
+                    width_range=IMAGE_RANGE,
+                    height=1024,
+                    height_range=IMAGE_RANGE,
+                    guidance_scale=7.0,
+                    guidance_scale_range=(1.0, 15.0),
+                    num_inference_steps=40,
+                    num_inference_steps_range=(10, 50),
+                    parameters=[
+                        "seed",
+                        "negative_prompt",
+                        "width",
+                        "height",
+                        "guidance_scale",
+                        "num_inference_steps",
+                    ],
+                ),
+            },
         ),
         "pplx": ServiceConfig(
-            "Perplexity",
-            "https://api.perplexity.ai",
-            os.environ.get("PPLX_API_KEY"),
+            name="Perplexity",
+            url="https://api.perplexity.ai",
+            api_key=os.environ.get("PPLX_API_KEY"),
+            text={
+                "llama-3.1-sonar-small-128k-chat": TextModelConfig(
+                    "Sonar Small (Offline)",
+                    **_pplx_text_kwargs,
+                ),
+                "llama-3.1-sonar-large-128k-chat": TextModelConfig(
+                    "Sonar Large (Offline)",
+                    **_pplx_text_kwargs,
+                ),
+                "llama-3.1-sonar-small-128k-online": TextModelConfig(
+                    "Sonar Small (Online)",
+                    **_pplx_text_kwargs,
+                ),
+                "llama-3.1-sonar-large-128k-online": TextModelConfig(
+                    "Sonar Large (Online)",
+                    **_pplx_text_kwargs,
+                ),
+                "llama-3.1-sonar-huge-128k-online": TextModelConfig(
+                    "Sonar Huge (Online)",
+                    **_pplx_text_kwargs,
+                ),
+            },
         ),
+        # TODO: text models
         "together": ServiceConfig(
-            "Together",
-            "https://api.together.xyz/v1/images/generations",
-            os.environ.get("TOGETHER_API_KEY"),
+            name="Together",
+            url="https://api.together.xyz/v1/images/generations",
+            api_key=os.environ.get("TOGETHER_API_KEY"),
+            image={
+                "black-forest-labs/FLUX.1-schnell-Free": ImageModelConfig(
+                    "FLUX.1 Schnell Free",
+                    width=1024,
+                    width_range=IMAGE_RANGE,
+                    height=1024,
+                    height_range=IMAGE_RANGE,
+                    num_inference_steps=4,
+                    num_inference_steps_range=(1, 12),
+                    parameters=["model", "seed", "width", "height", "steps"],
+                    kwargs={"n": 1},
+                ),
+            },
         ),
     },
-    txt2img=Txt2ImgConfig(
-        hidden_parameters=[
-            # Sent to API but not shown in generation parameters accordion
-            "enable_safety_checker",
-            "max_sequence_length",
-            "n",
-            "num_images",
-            "output_format",
-            "performance",
-            "safety_tolerance",
-            "scheduler",
-            "sharpness",
-            "style",
-            "styles",
-            "sync_mode",
-        ],
-        negative_prompt="ugly, unattractive, disfigured, deformed, mutated, malformed, blurry, grainy, oversaturated, undersaturated, overexposed, underexposed, worst quality, low details, lowres, watermark, signature, sloppy, cluttered",
-        default_image_size="square_hd",
-        image_sizes=[
-            "landscape_16_9",
-            "landscape_4_3",
-            "square_hd",
-            "portrait_4_3",
-            "portrait_16_9",
-        ],
-        default_aspect_ratio="1024x1024",  # fooocus aspect ratios
-        aspect_ratios=[
-            "704x1408",  # 1:2
-            "704x1344",  # 11:21
-            "768x1344",  # 4:7
-            "768x1280",  # 3:5
-            "832x1216",  # 13:19
-            "832x1152",  # 13:18
-            "896x1152",  # 7:9
-            "896x1088",  # 14:17
-            "960x1088",  # 15:17
-            "960x1024",  # 15:16
-            "1024x1024",
-            "1024x960",  # 16:15
-            "1088x960",  # 17:15
-            "1088x896",  # 17:14
-            "1152x896",  # 9:7
-            "1152x832",  # 18:13
-            "1216x832",  # 19:13
-            "1280x768",  # 5:3
-            "1344x768",  # 7:4
-            "1344x704",  # 21:11
-            "1408x704",  # 2:1
-        ],
-    ),
-    txt2txt=Txt2TxtConfig(
-        default_system="You are a helpful assistant. Be precise and concise.",
-    ),
 )
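Taken together, the new `AppConfig` replaces both the old `Config` and the deleted preset module: per-model defaults now hang off each service entry. A minimal sketch of how the pages read it (this mirrors the lookups in the page diffs below; the service and model keys are just examples taken from this file):

    from lib import config

    # App-wide settings
    print(config.title, config.timeout)  # "API Inference" 60

    # Per-service, per-model image settings
    model_config = config.services["hf"].image["black-forest-labs/flux.1-dev"]
    print(model_config.name)                  # "FLUX.1 Dev"
    print(model_config.guidance_scale_range)  # (1.5, 5.0)

    # Text models live on the same structure
    text_config = config.services["pplx"].text["llama-3.1-sonar-small-128k-chat"]
    print(text_config.max_tokens_range)       # (512, 4096)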
lib/preset.py DELETED
@@ -1,250 +0,0 @@
-from dataclasses import dataclass, field
-from typing import Dict, List, Optional, Union
-
-
-@dataclass
-class Txt2TxtPreset:
-    name: str
-    frequency_penalty: float
-    frequency_penalty_min: float
-    frequency_penalty_max: float
-    parameters: Optional[List[str]] = field(default_factory=list)
-
-
-@dataclass
-class Txt2ImgPreset:
-    name: str
-    guidance_scale: Optional[float] = None
-    guidance_scale_min: Optional[float] = None
-    guidance_scale_max: Optional[float] = None
-    num_inference_steps: Optional[int] = None
-    num_inference_steps_min: Optional[int] = None
-    num_inference_steps_max: Optional[int] = None
-    parameters: Optional[List[str]] = field(default_factory=list)
-    kwargs: Optional[Dict[str, Union[str, int, float, bool]]] = field(default_factory=dict)
-
-
-@dataclass
-class Preset:
-    txt2txt: Dict[str, Txt2TxtPreset]
-    txt2img: Dict[str, Txt2ImgPreset]
-
-
-hf_txt2txt_kwargs = {
-    "frequency_penalty": 0.0,
-    "frequency_penalty_min": -2.0,
-    "frequency_penalty_max": 2.0,
-    "parameters": ["max_tokens", "temperature", "frequency_penalty", "seed"],
-}
-
-pplx_txt2txt_kwargs = {
-    "frequency_penalty": 1.0,
-    "frequency_penalty_min": 1.0,
-    "frequency_penalty_max": 2.0,
-    "parameters": ["max_tokens", "temperature", "frequency_penalty"],
-}
-
-
-preset = Preset(
-    txt2txt={
-        "hf": {
-            # TODO: update models
-            "codellama/codellama-34b-instruct-hf": Txt2TxtPreset("Code Llama 34B", **hf_txt2txt_kwargs),
-            "meta-llama/llama-2-13b-chat-hf": Txt2TxtPreset("Llama 2 13B", **hf_txt2txt_kwargs),
-            "mistralai/mistral-7b-instruct-v0.2": Txt2TxtPreset("Mistral v0.2 7B", **hf_txt2txt_kwargs),
-            "nousresearch/nous-hermes-2-mixtral-8x7b-dpo": Txt2TxtPreset(
-                "Nous Hermes 2 Mixtral 8x7B",
-                **hf_txt2txt_kwargs,
-            ),
-        },
-        "pplx": {
-            "llama-3.1-sonar-small-128k-chat": Txt2TxtPreset("Sonar Small (Offline)", **pplx_txt2txt_kwargs),
-            "llama-3.1-sonar-large-128k-chat": Txt2TxtPreset("Sonar Large (Offline)", **pplx_txt2txt_kwargs),
-            "llama-3.1-sonar-small-128k-online": Txt2TxtPreset("Sonar Small (Online)", **pplx_txt2txt_kwargs),
-            "llama-3.1-sonar-large-128k-online": Txt2TxtPreset("Sonar Large (Online)", **pplx_txt2txt_kwargs),
-            "llama-3.1-sonar-huge-128k-online": Txt2TxtPreset("Sonar Huge (Online)", **pplx_txt2txt_kwargs),
-        },
-    },
-    txt2img={
-        "bfl": {
-            "flux-pro-1.1": Txt2ImgPreset(
-                "FLUX1.1 Pro",
-                parameters=["seed", "width", "height", "prompt_upsampling"],
-                kwargs={"safety_tolerance": 6},
-            ),
-            "flux-pro": Txt2ImgPreset(
-                "FLUX.1 Pro",
-                guidance_scale=2.5,
-                guidance_scale_min=1.5,
-                guidance_scale_max=5.0,
-                num_inference_steps=40,
-                num_inference_steps_min=10,
-                num_inference_steps_max=50,
-                parameters=["seed", "width", "height", "steps", "guidance", "prompt_upsampling"],
-                kwargs={"safety_tolerance": 6, "interval": 1},
-            ),
-            "flux-dev": Txt2ImgPreset(
-                "FLUX.1 Dev",
-                num_inference_steps=28,
-                num_inference_steps_min=10,
-                num_inference_steps_max=50,
-                guidance_scale=3.0,
-                guidance_scale_min=1.5,
-                guidance_scale_max=5.0,
-                parameters=["seed", "width", "height", "steps", "guidance", "prompt_upsampling"],
-                kwargs={"safety_tolerance": 6},
-            ),
-        },
-        "fal": {
-            "fal-ai/aura-flow": Txt2ImgPreset(
-                "AuraFlow",
-                guidance_scale=3.5,
-                guidance_scale_min=1.0,
-                guidance_scale_max=10.0,
-                num_inference_steps=28,
-                num_inference_steps_min=10,
-                num_inference_steps_max=50,
-                parameters=["seed", "num_inference_steps", "guidance_scale", "expand_prompt"],
-                kwargs={"num_images": 1, "sync_mode": False},
-            ),
-            "fal-ai/flux-pro/v1.1": Txt2ImgPreset(
-                "FLUX1.1 Pro",
-                parameters=["seed", "image_size"],
-                kwargs={
-                    "num_images": 1,
-                    "sync_mode": False,
-                    "safety_tolerance": 6,
-                    "enable_safety_checker": False,
-                },
-            ),
-            "fal-ai/flux-pro": Txt2ImgPreset(
-                "FLUX.1 Pro",
-                guidance_scale=2.5,
-                guidance_scale_min=1.5,
-                guidance_scale_max=5.0,
-                num_inference_steps=40,
-                num_inference_steps_min=10,
-                num_inference_steps_max=50,
-                parameters=["seed", "image_size", "num_inference_steps", "guidance_scale"],
-                kwargs={"num_images": 1, "sync_mode": False, "safety_tolerance": 6},
-            ),
-            "fal-ai/flux/dev": Txt2ImgPreset(
-                "FLUX.1 Dev",
-                num_inference_steps=28,
-                num_inference_steps_min=10,
-                num_inference_steps_max=50,
-                guidance_scale=3.0,
-                guidance_scale_min=1.5,
-                guidance_scale_max=5.0,
-                parameters=["seed", "image_size", "num_inference_steps", "guidance_scale"],
-                kwargs={"num_images": 1, "sync_mode": False, "safety_tolerance": 6},
-            ),
-            "fal-ai/flux/schnell": Txt2ImgPreset(
-                "FLUX.1 Schnell",
-                num_inference_steps=4,
-                num_inference_steps_min=1,
-                num_inference_steps_max=12,
-                parameters=["seed", "image_size", "num_inference_steps"],
-                kwargs={"num_images": 1, "sync_mode": False, "enable_safety_checker": False},
-            ),
-            "fal-ai/fooocus": Txt2ImgPreset(
-                "Fooocus",
-                guidance_scale=4.0,
-                guidance_scale_min=1.0,
-                guidance_scale_max=10.0,
-                parameters=["seed", "negative_prompt", "aspect_ratio", "guidance_scale"],
-                kwargs={
-                    "num_images": 1,
-                    "sync_mode": True,
-                    "enable_safety_checker": False,
-                    "output_format": "png",
-                    "sharpness": 2,
-                    "styles": ["Fooocus Enhance", "Fooocus V2", "Fooocus Sharp"],
-                    "performance": "Quality",
-                },
-            ),
-            "fal-ai/kolors": Txt2ImgPreset(
-                "Kolors",
-                guidance_scale=5.0,
-                guidance_scale_min=1.0,
-                guidance_scale_max=10.0,
-                num_inference_steps=50,
-                num_inference_steps_min=10,
-                num_inference_steps_max=50,
-                parameters=["seed", "negative_prompt", "image_size", "guidance_scale", "num_inference_steps"],
-                kwargs={
-                    "num_images": 1,
-                    "sync_mode": True,
-                    "enable_safety_checker": False,
-                    "scheduler": "EulerDiscreteScheduler",
-                },
-            ),
-            "fal-ai/stable-diffusion-v3-medium": Txt2ImgPreset(
-                "SD3",
-                guidance_scale=5.0,
-                guidance_scale_min=1.0,
-                guidance_scale_max=10.0,
-                num_inference_steps=28,
-                num_inference_steps_min=10,
-                num_inference_steps_max=50,
-                parameters=[
-                    "seed",
-                    "negative_prompt",
-                    "image_size",
-                    "guidance_scale",
-                    "num_inference_steps",
-                    "prompt_expansion",
-                ],
-                kwargs={"num_images": 1, "sync_mode": True, "enable_safety_checker": False},
-            ),
-        },
-        "hf": {
-            "black-forest-labs/flux.1-dev": Txt2ImgPreset(
-                "FLUX.1 Dev",
-                num_inference_steps=28,
-                num_inference_steps_min=10,
-                num_inference_steps_max=50,
-                guidance_scale=3.0,
-                guidance_scale_min=1.5,
-                guidance_scale_max=5.0,
-                parameters=["width", "height", "guidance_scale", "num_inference_steps"],
-                kwargs={"max_sequence_length": 512},
-            ),
-            "black-forest-labs/flux.1-schnell": Txt2ImgPreset(
-                "FLUX.1 Schnell",
-                num_inference_steps=4,
-                num_inference_steps_min=1,
-                num_inference_steps_max=12,
-                parameters=["width", "height", "num_inference_steps"],
-                kwargs={"guidance_scale": 0.0, "max_sequence_length": 256},
-            ),
-            "stabilityai/stable-diffusion-xl-base-1.0": Txt2ImgPreset(
-                "SDXL",
-                guidance_scale=7.0,
-                guidance_scale_min=1.0,
-                guidance_scale_max=10.0,
-                num_inference_steps=40,
-                num_inference_steps_min=10,
-                num_inference_steps_max=50,
-                parameters=[
-                    "seed",
-                    "negative_prompt",
-                    "width",
-                    "height",
-                    "guidance_scale",
-                    "num_inference_steps",
-                ],
-            ),
-        },
-        "together": {
-            "black-forest-labs/FLUX.1-schnell-Free": Txt2ImgPreset(
-                "FLUX.1 Schnell Free",
-                num_inference_steps=4,
-                num_inference_steps_min=1,
-                num_inference_steps_max=12,
-                parameters=["model", "seed", "width", "height", "steps"],
-                kwargs={"n": 1},
-            ),
-        },
-    },
-)
pages/1_💬_Text_Generation.py CHANGED
@@ -1,9 +1,8 @@
 from datetime import datetime
-from typing import Dict

 import streamlit as st

-from lib import Txt2TxtPreset, config, preset, txt2txt_generate
+from lib import config, txt2txt_generate

 # config
 st.set_page_config(
@@ -32,19 +31,24 @@ if "txt2txt_seed" not in st.session_state:
 st.logo(config.logo)
 st.sidebar.header("Settings")

+text_services = {
+    service_id: service_config
+    for service_id, service_config in config.services.items()
+    if getattr(service_config, "text", None)
+}
+
 service = st.sidebar.selectbox(
     "Service",
-    options=preset.txt2txt.keys(),
-    format_func=lambda x: config.service[x].name,
+    options=text_services.keys(),
+    format_func=lambda x: text_services[x].name,
     disabled=st.session_state.running,
 )

-# disable API key input and hide value if set by environment variable (handle empty string value later)
-# for display_name, session_key in SERVICE_SESSION.items():
-for service_id, service_config in config.service.items():
+# Show the API key input for the selected service.
+for service_id, service_preset in text_services.items():
     if service == service_id:
         session_key = f"api_key_{service}"
-        api_key = config.service[service].api_key
+        api_key = service_preset.api_key
         st.session_state[session_key] = st.sidebar.text_input(
             "API Key",
             type="password",
@@ -53,33 +57,33 @@ for service_id, service_config in config.service.items():
             help="Set by environment variable" if api_key else "Cleared on page refresh",
         )

-service_preset: Dict[str, Txt2TxtPreset] = preset.txt2txt[service]
+service_config = text_services[service]

 model = st.sidebar.selectbox(
     "Model",
-    options=service_preset.keys(),
-    format_func=lambda x: service_preset[x].name,
+    options=service_config.text.keys(),
+    format_func=lambda x: service_config.text[x].name,
     disabled=st.session_state.running,
 )

-model_preset = service_preset[model]
+model_config = service_config.text[model]

 system = st.sidebar.text_area(
     "System Message",
-    value=config.txt2txt.default_system,
+    value=model_config.system_prompt,
     disabled=st.session_state.running,
 )

 # build parameters from preset
 parameters = {}
-for param in model_preset.parameters:
+for param in model_config.parameters:
     if param == "max_tokens":
         parameters[param] = st.sidebar.slider(
             "Max Tokens",
             step=512,
-            value=512,
-            min_value=512,
-            max_value=4096,
+            value=model_config.max_tokens,
+            min_value=model_config.max_tokens_range[0],
+            max_value=model_config.max_tokens_range[1],
             disabled=st.session_state.running,
             help="Maximum number of tokens to generate (default: 512)",
         )
@@ -87,9 +91,9 @@ for param in model_preset.parameters:
         parameters[param] = st.sidebar.slider(
             "Temperature",
             step=0.1,
-            value=1.0,
-            min_value=0.0,
-            max_value=2.0,
+            value=model_config.temperature,
+            min_value=model_config.temperature_range[0],
+            max_value=model_config.temperature_range[1],
             disabled=st.session_state.running,
             help="Used to modulate the next token probabilities (default: 1.0)",
         )
@@ -97,9 +101,9 @@ for param in model_preset.parameters:
         parameters[param] = st.sidebar.slider(
             "Frequency Penalty",
             step=0.1,
-            value=model_preset.frequency_penalty,
-            min_value=model_preset.frequency_penalty_min,
-            max_value=model_preset.frequency_penalty_max,
+            value=model_config.frequency_penalty,
+            min_value=model_config.frequency_penalty_range[0],
+            max_value=model_config.frequency_penalty_range[1],
             disabled=st.session_state.running,
             help="Penalize new tokens based on their existing frequency in the text (default: 0.0)",
         )
@@ -177,7 +181,7 @@ if prompt := st.chat_input(

     with st.chat_message("assistant"):
         session_key = f"api_key_{service}"
-        api_key = st.session_state[session_key] or config.service[service].api_key
+        api_key = st.session_state[session_key] or text_services[service].api_key
         response = txt2txt_generate(api_key, service, model, parameters)
         st.session_state.running = False

pages/2_🎨_Text_to_Image.py CHANGED
@@ -1,9 +1,8 @@
 from datetime import datetime
-from typing import Dict

 import streamlit as st

-from lib import Txt2ImgPreset, config, preset, txt2img_generate
+from lib import config, txt2img_generate

 st.set_page_config(
     page_title=f"{config.title} | Text to Image",
@@ -36,20 +35,24 @@ if "txt2img_seed" not in st.session_state:
 st.logo(config.logo)
 st.sidebar.header("Settings")

+image_services = {
+    service_id: service_config
+    for service_id, service_config in config.services.items()
+    if getattr(service_config, "image", None)
+}
+
 service = st.sidebar.selectbox(
     "Service",
-    options=preset.txt2img.keys(),
-    format_func=lambda x: config.service[x].name,
+    options=image_services.keys(),
+    format_func=lambda x: image_services[x].name,
     disabled=st.session_state.running,
 )

 # Show the API key input for the selected service.
-# Disable and hide value if set by environment variable; handle empty string value later.
-# for display_name, session_key in SERVICE_SESSION.items():
-for service_id in config.service.keys():
+for service_id, service_config in image_services.items():
     if service == service_id:
         session_key = f"api_key_{service}"
-        api_key = config.service[service].api_key
+        api_key = service_config.api_key
         st.session_state[session_key] = st.sidebar.text_input(
             "API Key",
             type="password",
@@ -58,16 +61,16 @@ for service_id in config.service.keys():
             help="Set by environment variable" if api_key else "Cleared on page refresh",
         )

-service_preset: Dict[str, Txt2ImgPreset] = preset.txt2img[service]
+service_config = image_services[service]

 model = st.sidebar.selectbox(
     "Model",
-    options=service_preset.keys(),
-    format_func=lambda x: service_preset[x].name,
+    options=service_config.image.keys(),
+    format_func=lambda x: service_config.image[x].name,
     disabled=st.session_state.running,
 )

-model_preset = service_preset[model]
+model_config = service_config.image[model]

 # heading
 st.html("""
@@ -77,7 +80,7 @@ st.html("""

 # Build parameters from preset by rendering the appropriate input widgets
 parameters = {}
-for param in model_preset.parameters:
+for param in model_config.parameters:
     if param == "model":
         parameters[param] = model
     if param == "seed":
@@ -91,56 +94,56 @@ for param in model_preset.parameters:
     if param == "negative_prompt":
         parameters[param] = st.sidebar.text_area(
             "Negative Prompt",
-            value=config.txt2img.negative_prompt,
+            value=model_config.negative_prompt,
             disabled=st.session_state.running,
         )
     if param == "width":
         parameters[param] = st.sidebar.slider(
             "Width",
             step=64,
-            value=1024,
-            min_value=512,
-            max_value=2048,
+            value=model_config.width,
+            min_value=model_config.width_range[0],
+            max_value=model_config.width_range[1],
             disabled=st.session_state.running,
         )
     if param == "height":
         parameters[param] = st.sidebar.slider(
             "Height",
             step=64,
-            value=1024,
-            min_value=512,
-            max_value=2048,
+            value=model_config.height,
+            min_value=model_config.height_range[0],
+            max_value=model_config.height_range[1],
            disabled=st.session_state.running,
         )
     if param == "image_size":
         parameters[param] = st.sidebar.select_slider(
             "Image Size",
-            options=config.txt2img.image_sizes,
-            value=config.txt2img.default_image_size,
+            options=model_config.image_sizes,
+            value=model_config.image_size,
             disabled=st.session_state.running,
         )
     if param == "aspect_ratio":
         parameters[param] = st.sidebar.select_slider(
             "Aspect Ratio",
-            options=config.txt2img.aspect_ratios,
-            value=config.txt2img.default_aspect_ratio,
+            options=model_config.aspect_ratios,
+            value=model_config.aspect_ratio,
             disabled=st.session_state.running,
         )
     if param in ["guidance_scale", "guidance"]:
         parameters[param] = st.sidebar.slider(
             "Guidance Scale",
-            model_preset.guidance_scale_min,
-            model_preset.guidance_scale_max,
-            model_preset.guidance_scale,
+            model_config.guidance_scale_range[0],
+            model_config.guidance_scale_range[1],
+            model_config.guidance_scale,
             0.1,
             disabled=st.session_state.running,
         )
     if param in ["num_inference_steps", "steps"]:
         parameters[param] = st.sidebar.slider(
             "Inference Steps",
-            model_preset.num_inference_steps_min,
-            model_preset.num_inference_steps_max,
-            model_preset.num_inference_steps,
+            model_config.num_inference_steps_range[0],
+            model_config.num_inference_steps_range[1],
+            model_config.num_inference_steps,
             1,
             disabled=st.session_state.running,
         )
@@ -175,7 +178,7 @@ for message in st.session_state.txt2img_messages:
         filtered_parameters = [
             f"`{k}`: {v}"
             for k, v in message["parameters"].items()
-            if k not in config.txt2img.hidden_parameters
+            if k not in config.hidden_parameters
         ]
         st.markdown(f"`model`: {message['model']}\n\n" + "\n\n".join(filtered_parameters))

@@ -248,15 +251,15 @@ if prompt := st.chat_input(

     with st.chat_message("assistant"):
         with st.spinner("Running..."):
-            if model_preset.kwargs:
-                parameters.update(model_preset.kwargs)
+            if model_config.kwargs:
+                parameters.update(model_config.kwargs)
             session_key = f"api_key_{service}"
-            api_key = st.session_state[session_key] or config.service[service].api_key
+            api_key = st.session_state[session_key] or image_services[service].api_key
             image = txt2img_generate(api_key, service, model, prompt, parameters)
             st.session_state.running = False

     st.session_state.txt2img_messages.append(
-        {"role": "user", "content": prompt, "parameters": parameters, "model": model_preset.name}
+        {"role": "user", "content": prompt, "parameters": parameters, "model": model_config.name}
     )
     st.session_state.txt2img_messages.append({"role": "assistant", "content": image})
     st.rerun()