adamelliotfields committed
Commit 1367e6b
Parent: ecbcd62

Convert config to dataclass

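In short: the flat SimpleNamespace named Config becomes a set of nested dataclasses exposed as a single config instance, and every call site moves from flat upper-case attributes (Config.TXT2IMG_TIMEOUT) to grouped lower-case fields (config.txt2img.timeout). A minimal sketch of the before/after access pattern, trimmed to two fields rather than copied from the repo:

from dataclasses import dataclass
from types import SimpleNamespace

# Before: one flat namespace with upper-case keys and no type hints.
OldConfig = SimpleNamespace(TITLE="API Inference", TXT2IMG_TIMEOUT=60)
print(OldConfig.TITLE, OldConfig.TXT2IMG_TIMEOUT)  # API Inference 60


# After: typed dataclasses, with related settings grouped under txt2img / txt2txt.
@dataclass
class Txt2ImgConfig:
    timeout: int = 60


@dataclass
class Config:
    title: str
    txt2img: Txt2ImgConfig


config = Config(title="API Inference", txt2img=Txt2ImgConfig())
print(config.title, config.txt2img.timeout)  # API Inference 60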
0_🏠_Home.py CHANGED
@@ -1,11 +1,11 @@
 import streamlit as st
 
-from lib import Config
+from lib import config
 
 st.set_page_config(
-    page_title=Config.TITLE,
-    page_icon=Config.ICON,
-    layout=Config.LAYOUT,
+    page_title=config.title,
+    page_icon=config.icon,
+    layout=config.layout,
 )
 
 # sidebar
lib/__init__.py CHANGED
@@ -1,9 +1,9 @@
 from .api import txt2img_generate, txt2txt_generate
-from .config import Config
+from .config import config
 from .presets import ModelPresets, ServicePresets
 
 __all__ = [
-    "Config",
+    "config",
     "ModelPresets",
     "ServicePresets",
    "txt2img_generate",
lib/api.py CHANGED
@@ -7,11 +7,11 @@ import streamlit as st
 from openai import APIError, OpenAI
 from PIL import Image
 
-from .config import Config
+from .config import config
 
 
 def txt2txt_generate(api_key, service, model, parameters, **kwargs):
-    base_url = Config.SERVICES[service]
+    base_url = config.services[service]
     if service == "Hugging Face":
         base_url = f"{base_url}/{model}/v1"
     client = OpenAI(api_key=api_key, base_url=base_url)
@@ -62,23 +62,23 @@ def txt2img_generate(api_key, service, model, inputs, parameters, **kwargs):
     json = {**parameters, **kwargs}
     json["prompt"] = inputs
 
-    base_url = Config.SERVICES[service]
+    base_url = config.services[service]
 
     if service not in ["Together"]:
         base_url = f"{base_url}/{model}"
 
     try:
-        response = httpx.post(base_url, headers=headers, json=json, timeout=Config.TXT2IMG_TIMEOUT)
+        response = httpx.post(base_url, headers=headers, json=json, timeout=config.txt2img.timeout)
         if response.status_code // 100 == 2:  # 2xx
             # BFL is async so we need to poll for result
             # https://api.bfl.ml/docs
             if service == "Black Forest Labs":
                 id = response.json()["id"]
-                url = f"{Config.SERVICES[service]}/get_result?id={id}"
+                url = f"{config.services[service]}/get_result?id={id}"
 
                 retries = 0
-                while retries < Config.TXT2IMG_TIMEOUT:
-                    response = httpx.get(url, timeout=Config.TXT2IMG_TIMEOUT)
+                while retries < config.txt2img.timeout:
+                    response = httpx.get(url, timeout=config.txt2img.timeout)
                     if response.status_code // 100 != 2:
                         return f"Error: {response.status_code} {response.text}"
 
@@ -86,7 +86,7 @@ def txt2img_generate(api_key, service, model, inputs, parameters, **kwargs):
                         image = httpx.get(
                             response.json()["result"]["sample"],
                             headers=headers,
-                            timeout=Config.TXT2IMG_TIMEOUT,
+                            timeout=config.txt2img.timeout,
                         )
                         return Image.open(io.BytesIO(image.content))
 
@@ -102,7 +102,7 @@ def txt2img_generate(api_key, service, model, inputs, parameters, **kwargs):
                 return Image.open(io.BytesIO(bytes))
             else:
                 url = response.json()["images"][0]["url"]
-                image = httpx.get(url, headers=headers, timeout=Config.TXT2IMG_TIMEOUT)
+                image = httpx.get(url, headers=headers, timeout=config.txt2img.timeout)
                 return Image.open(io.BytesIO(image.content))
 
             if service == "Hugging Face":
@@ -110,7 +110,7 @@ def txt2img_generate(api_key, service, model, inputs, parameters, **kwargs):
 
             if service == "Together":
                 url = response.json()["data"][0]["url"]
-                image = httpx.get(url, headers=headers, timeout=Config.TXT2IMG_TIMEOUT)
+                image = httpx.get(url, headers=headers, timeout=config.txt2img.timeout)
                 return Image.open(io.BytesIO(image.content))
 
         else:
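The only functional change in lib/api.py is the lookup path: Config.SERVICES and Config.TXT2IMG_TIMEOUT become config.services and config.txt2img.timeout, and that one value still does double duty as the httpx timeout and as the retry cap when polling Black Forest Labs. A standalone sketch of that polling pattern; the "Ready" check, the sleep, and the retries increment are assumptions based on the linked BFL docs, not lines shown in this diff:

import io
import time

import httpx
from PIL import Image

TIMEOUT = 60  # stands in for config.txt2img.timeout


def poll_bfl_result(url: str, headers: dict):
    """Poll a BFL get_result URL until the sample image is ready or we give up."""
    retries = 0
    while retries < TIMEOUT:  # the same number caps both attempts and the per-request timeout
        response = httpx.get(url, timeout=TIMEOUT)
        if response.status_code // 100 != 2:
            return f"Error: {response.status_code} {response.text}"
        if response.json().get("status") == "Ready":
            image = httpx.get(response.json()["result"]["sample"], headers=headers, timeout=TIMEOUT)
            return Image.open(io.BytesIO(image.content))
        retries += 1
        time.sleep(1)
    return "Error: timed out waiting for the BFL result"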
lib/config.py CHANGED
@@ -1,115 +1,149 @@
-from types import SimpleNamespace
-
-Config = SimpleNamespace(
-    TITLE="API Inference",
-    ICON="⚑",
-    LAYOUT="wide",
-    SERVICES={
-        "Black Forest Labs": "https://api.bfl.ml/v1",
-        "Fal": "https://fal.run",
-        "Hugging Face": "https://api-inference.huggingface.co/models",
-        "Perplexity": "https://api.perplexity.ai",
-        "Together": "https://api.together.xyz/v1/images/generations",
-    },
-    TXT2IMG_TIMEOUT=60,
-    TXT2IMG_HIDDEN_PARAMETERS=[
-        # sent to API but not shown in generation parameters accordion
-        "enable_safety_checker",
-        "max_sequence_length",
-        "num_images",
-        "output_format",
-        "performance",
-        "safety_tolerance",
-        "scheduler",
-        "sharpness",
-        "style",
-        "styles",
-        "sync_mode",
-    ],
-    TXT2IMG_NEGATIVE_PROMPT="ugly, unattractive, disfigured, deformed, mutated, malformed, blurry, grainy, noisy, oversaturated, undersaturated, overexposed, underexposed, worst quality, low details, lowres, watermark, signature, autograph, trademark, sloppy, cluttered",
-    TXT2IMG_DEFAULT_MODEL={
-        # The index of model in below lists
-        "Black Forest Labs": 2,
-        "Fal": 0,
-        "Hugging Face": 2,
-        "Together": 0,
-    },
-    TXT2IMG_MODELS={
-        # Model IDs referenced in Text_to_Image.py
-        "Black Forest Labs": [
-            "flux-dev",
-            "flux-pro",
-            "flux-pro-1.1",
-        ],
-        "Fal": [
-            "fal-ai/aura-flow",
-            "fal-ai/flux/dev",
-            "fal-ai/flux/schnell",
-            "fal-ai/flux-pro",
-            "fal-ai/flux-pro/v1.1",
-            "fal-ai/fooocus",
-            "fal-ai/kolors",
-            "fal-ai/stable-diffusion-v3-medium",
-        ],
-        "Hugging Face": [
-            "black-forest-labs/flux.1-dev",
-            "black-forest-labs/flux.1-schnell",
-            "stabilityai/stable-diffusion-xl-base-1.0",
-        ],
-        "Together": [
-            "black-forest-labs/FLUX.1-schnell-Free",
-        ],
-    },
-    TXT2IMG_DEFAULT_IMAGE_SIZE="square_hd",  # fal image sizes
-    TXT2IMG_IMAGE_SIZES=[
-        "landscape_16_9",
-        "landscape_4_3",
-        "square_hd",
-        "portrait_4_3",
-        "portrait_16_9",
-    ],
-    TXT2IMG_DEFAULT_ASPECT_RATIO="1024x1024",  # fooocus aspect ratios
-    TXT2IMG_ASPECT_RATIOS=[
-        "704x1408",  # 1:2
-        "704x1344",  # 11:21
-        "768x1344",  # 4:7
-        "768x1280",  # 3:5
-        "832x1216",  # 13:19
-        "832x1152",  # 13:18
-        "896x1152",  # 7:9
-        "896x1088",  # 14:17
-        "960x1088",  # 15:17
-        "960x1024",  # 15:16
-        "1024x1024",
-        "1024x960",  # 16:15
-        "1088x960",  # 17:15
-        "1088x896",  # 17:14
-        "1152x896",  # 9:7
-        "1152x832",  # 18:13
-        "1216x832",  # 19:13
-        "1280x768",  # 5:3
-        "1344x768",  # 7:4
-        "1344x704",  # 21:11
-        "1408x704",  # 2:1
-    ],
-    TXT2TXT_DEFAULT_SYSTEM="You are a helpful assistant. Be precise and concise.",
-    TXT2TXT_DEFAULT_MODEL={
-        "Hugging Face": 4,
-        "Perplexity": 3,
-    },
-    TXT2TXT_MODELS={
-        "Hugging Face": [
-            "codellama/codellama-34b-instruct-hf",
-            "meta-llama/llama-2-13b-chat-hf",
-            "meta-llama/meta-llama-3.1-405b-instruct-fp8",
-            "mistralai/mistral-7b-instruct-v0.2",
-            "nousresearch/nous-hermes-2-mixtral-8x7b-dpo",
-        ],
-        "Perplexity": [
-            "llama-3.1-sonar-small-128k-chat",
-            "llama-3.1-sonar-large-128k-chat",
-            "llama-3.1-sonar-small-128k-online",
-            "llama-3.1-sonar-large-128k-online",
-        ],
-    },
-)
+from dataclasses import dataclass
+from typing import Dict, List
+
+
+@dataclass
+class Txt2TxtConfig:
+    default_system: str
+    default_model: Dict[str, int]
+    models: Dict[str, List[str]]
+
+
+@dataclass
+class Txt2ImgConfig:
+    default_model: Dict[str, int]
+    models: Dict[str, List[str]]
+    hidden_parameters: List[str]
+    negative_prompt: str
+    default_image_size: str
+    image_sizes: List[str]
+    default_aspect_ratio: str
+    aspect_ratios: List[str]
+    timeout: int = 60
+
+
+@dataclass
+class Config:
+    title: str
+    icon: str
+    layout: str
+    services: Dict[str, str]
+    txt2img: Txt2ImgConfig
+    txt2txt: Txt2TxtConfig
+
+
+config = Config(
+    title="API Inference",
+    icon="⚑",
+    layout="wide",
+    services={
+        "Black Forest Labs": "https://api.bfl.ml/v1",
+        "Fal": "https://fal.run",
+        "Hugging Face": "https://api-inference.huggingface.co/models",
+        "Perplexity": "https://api.perplexity.ai",
+        "Together": "https://api.together.xyz/v1/images/generations",
+    },
+    txt2img=Txt2ImgConfig(
+        default_model={
+            "Black Forest Labs": 2,
+            "Fal": 0,
+            "Hugging Face": 2,
+            "Together": 0,
+        },
+        models={
+            # Model identifiers referenced in Text_to_Image.py
+            "Black Forest Labs": [
+                "flux-dev",
+                "flux-pro",
+                "flux-pro-1.1",
+            ],
+            "Fal": [
+                "fal-ai/aura-flow",
+                "fal-ai/flux/dev",
+                "fal-ai/flux/schnell",
+                "fal-ai/flux-pro",
+                "fal-ai/flux-pro/v1.1",
+                "fal-ai/fooocus",
+                "fal-ai/kolors",
+                "fal-ai/stable-diffusion-v3-medium",
+            ],
+            "Hugging Face": [
+                "black-forest-labs/flux.1-dev",
+                "black-forest-labs/flux.1-schnell",
+                "stabilityai/stable-diffusion-xl-base-1.0",
+            ],
+            "Together": [
+                "black-forest-labs/FLUX.1-schnell-Free",
+            ],
+        },
+        hidden_parameters=[
+            # sent to API but not shown in generation parameters accordion
+            "enable_safety_checker",
+            "max_sequence_length",
+            "num_images",
+            "output_format",
+            "performance",
+            "safety_tolerance",
+            "scheduler",
+            "sharpness",
+            "style",
+            "styles",
+            "sync_mode",
+        ],
+        negative_prompt="ugly, unattractive, disfigured, deformed, mutated, malformed, blurry, grainy, noisy, oversaturated, undersaturated, overexposed, underexposed, worst quality, low details, lowres, watermark, signature, autograph, trademark, sloppy, cluttered",
+        default_image_size="square_hd",
+        image_sizes=[
+            "landscape_16_9",
+            "landscape_4_3",
+            "square_hd",
+            "portrait_4_3",
+            "portrait_16_9",
+        ],
+        default_aspect_ratio="1024x1024",
+        aspect_ratios=[
+            "704x1408",  # 1:2
+            "704x1344",  # 11:21
+            "768x1344",  # 4:7
+            "768x1280",  # 3:5
+            "832x1216",  # 13:19
+            "832x1152",  # 13:18
+            "896x1152",  # 7:9
+            "896x1088",  # 14:17
+            "960x1088",  # 15:17
+            "960x1024",  # 15:16
+            "1024x1024",
+            "1024x960",  # 16:15
+            "1088x960",  # 17:15
+            "1088x896",  # 17:14
+            "1152x896",  # 9:7
+            "1152x832",  # 18:13
+            "1216x832",  # 19:13
+            "1280x768",  # 5:3
+            "1344x768",  # 7:4
+            "1344x704",  # 21:11
+            "1408x704",  # 2:1
+        ],
+    ),
+    txt2txt=Txt2TxtConfig(
+        default_system="You are a helpful assistant. Be precise and concise.",
+        default_model={
+            "Hugging Face": 4,
+            "Perplexity": 3,
+        },
+        models={
+            "Hugging Face": [
+                "codellama/codellama-34b-instruct-hf",
+                "meta-llama/llama-2-13b-chat-hf",
+                "meta-llama/meta-llama-3.1-405b-instruct-fp8",
+                "mistralai/mistral-7b-instruct-v0.2",
+                "nousresearch/nous-hermes-2-mixtral-8x7b-dpo",
+            ],
+            "Perplexity": [
+                "llama-3.1-sonar-small-128k-chat",
+                "llama-3.1-sonar-large-128k-chat",
+                "llama-3.1-sonar-small-128k-online",
+                "llama-3.1-sonar-large-128k-online",
+            ],
+        },
+    ),
+)
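The pages below consume the nested config by service name: each service's model list sits next to its default index, so picking the default model is a dictionary lookup plus a list index. A small usage sketch, assuming it runs inside this repo so the import resolves (lib/__init__.py above re-exports config):

from lib import config

service = "Hugging Face"

models = config.txt2img.models[service]                  # model ids for the selected service
default = models[config.txt2img.default_model[service]]  # the default index points into that list

print(default)  # stabilityai/stable-diffusion-xl-base-1.0 (index 2)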
pages/1_💬_Text_Generation.py CHANGED
@@ -3,7 +3,7 @@ from datetime import datetime
 
 import streamlit as st
 
-from lib import Config, ServicePresets, txt2txt_generate
+from lib import ServicePresets, config, txt2txt_generate
 
 SERVICE_SESSION = {
     "Hugging Face": "api_key_hugging_face",
@@ -17,9 +17,9 @@ SESSION_TOKEN = {
 
 # config
 st.set_page_config(
-    page_title=f"{Config.TITLE} | Text Generation",
-    page_icon=Config.ICON,
-    layout=Config.LAYOUT,
+    page_title=f"{config.title} | Text Generation",
+    page_icon=config.icon,
+    layout=config.layout,
 )
 
 # initialize state
@@ -61,14 +61,14 @@ for display_name, session_key in SERVICE_SESSION.items():
 
 model = st.sidebar.selectbox(
     "Model",
-    options=Config.TXT2TXT_MODELS[service],
-    index=Config.TXT2TXT_DEFAULT_MODEL[service],
+    options=config.txt2txt.models[service],
+    index=config.txt2txt.default_model[service],
     disabled=st.session_state.running,
     format_func=lambda x: x.split("/")[1] if service == "Hugging Face" else x,
 )
 system = st.sidebar.text_area(
     "System Message",
-    value=Config.TXT2TXT_DEFAULT_SYSTEM,
+    value=config.txt2txt.default_system,
     disabled=st.session_state.running,
 )
 
pages/2_🎨_Text_to_Image.py CHANGED
@@ -3,7 +3,7 @@ from datetime import datetime
 
 import streamlit as st
 
-from lib import Config, ModelPresets, txt2img_generate
+from lib import ModelPresets, config, txt2img_generate
 
 # The token name is the service in lower_snake_case
 SERVICE_SESSION = {
@@ -45,9 +45,9 @@ PRESET_MODEL = {
 }
 
 st.set_page_config(
-    page_title=f"{Config.TITLE} | Text to Image",
-    page_icon=Config.ICON,
-    layout=Config.LAYOUT,
+    page_title=f"{config.title} | Text to Image",
+    page_icon=config.icon,
+    layout=config.layout,
 )
 
 # Initialize Streamlit session state
@@ -94,8 +94,8 @@ for display_name, session_key in SERVICE_SESSION.items():
 
 model = st.sidebar.selectbox(
     "Model",
-    options=Config.TXT2IMG_MODELS[service],
-    index=Config.TXT2IMG_DEFAULT_MODEL[service],
+    options=config.txt2img.models[service],
+    index=config.txt2img.default_model[service],
     disabled=st.session_state.running,
 )
 
@@ -122,7 +122,7 @@ for param in preset["parameters"]:
     if param == "negative_prompt":
         parameters[param] = st.sidebar.text_area(
             "Negative Prompt",
-            value=Config.TXT2IMG_NEGATIVE_PROMPT,
+            value=config.txt2img.negative_prompt,
            disabled=st.session_state.running,
        )
     if param == "width":
@@ -146,15 +146,15 @@ for param in preset["parameters"]:
     if param == "image_size":
         parameters[param] = st.sidebar.select_slider(
             "Image Size",
-            options=Config.TXT2IMG_IMAGE_SIZES,
-            value=Config.TXT2IMG_DEFAULT_IMAGE_SIZE,
+            options=config.txt2img.image_sizes,
+            value=config.txt2img.default_image_size,
             disabled=st.session_state.running,
         )
     if param == "aspect_ratio":
         parameters[param] = st.sidebar.select_slider(
             "Aspect Ratio",
-            options=Config.TXT2IMG_ASPECT_RATIOS,
-            value=Config.TXT2IMG_DEFAULT_ASPECT_RATIO,
+            options=config.txt2img.aspect_ratios,
+            value=config.txt2img.default_aspect_ratio,
             disabled=st.session_state.running,
         )
     if param in ["guidance_scale", "guidance"]:
@@ -206,7 +206,7 @@ for message in st.session_state.txt2img_messages:
         filtered_parameters = [
             f"`{k}`: {v}"
             for k, v in message["parameters"].items()
-            if k not in Config.TXT2IMG_HIDDEN_PARAMETERS
+            if k not in config.txt2img.hidden_parameters
         ]
         st.markdown(f"`model`: {message['model']}\n\n" + "\n\n".join(filtered_parameters))
 