adamelliotfields committed
Commit 8407ae2 • Parent(s): 6f2752e

Add Anthropic API
Browse files
- 0_🏠_Home.py +1 -0
- README.md +1 -0
- lib/api.py +17 -7
- lib/config.py +21 -1
- pages/1_💬_Text_Generation.py +14 -2
- requirements.txt +1 -0
0_🏠_Home.py
CHANGED
@@ -49,6 +49,7 @@ st.page_link("pages/2_🎨_Text_to_Image.py", label="Text to Image", icon="🎨")
 st.markdown("""
 ## Services
 
+- [Anthropic](https://docs.anthropic.com/en/api/getting-started)
 - [Black Forest Labs](https://docs.bfl.ml)
 - [fal.ai](https://fal.ai/docs)
 - [Hugging Face](https://huggingface.co/docs/api-inference/index)
README.md
CHANGED
@@ -27,6 +27,7 @@ models:
 Setting keys as environment variables persists them so you don't have to enter them on every page load.
 
 ```bash
+ANTHROPIC_API_KEY=...
 BFL_API_KEY=...
 FAL_KEY=...
 HF_TOKEN=...
lib/api.py
CHANGED
@@ -4,23 +4,33 @@ import time
 
 import httpx
 import streamlit as st
-from openai import APIError, OpenAI
+from anthropic import Anthropic
+from anthropic import APIError as AnthropicAPIError
+from openai import APIError as OpenAIAPIError
+from openai import OpenAI
 from PIL import Image
 
 from .config import config
 
 
-def txt2txt_generate(api_key, service, model, parameters, **kwargs):
+def txt2txt_generate(api_key, service, parameters, **kwargs):
     base_url = config.services[service].url
 
     if service == "hf":
-        base_url = f"{base_url}/{model}/v1"
-    client = OpenAI(api_key=api_key, base_url=base_url)
+        base_url = f"{base_url}/{parameters['model']}/v1"
 
     try:
-        stream = client.chat.completions.create(model=model, stream=True, **parameters, **kwargs)
-        return st.write_stream(stream)
-    except APIError as e:
+        if service == "anthropic":
+            client = Anthropic(api_key=api_key)
+            with client.messages.stream(**parameters, **kwargs) as stream:
+                return st.write_stream(stream.text_stream)
+        else:
+            client = OpenAI(api_key=api_key, base_url=base_url)
+            stream = client.chat.completions.create(stream=True, **parameters, **kwargs)
+            return st.write_stream(stream)
+    except AnthropicAPIError as e:
+        return e.message
+    except OpenAIAPIError as e:
         # OpenAI uses this message for streaming errors and attaches response.error to error.body
         # https://github.com/openai/openai-python/blob/v1.0.0/src/openai/_streaming.py#L59
         return e.body if e.message == "An error occurred during streaming" else e.message
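The two branches stream differently: the Anthropic SDK returns a context-managed stream whose `text_stream` yields plain strings, while the OpenAI client's `create(stream=True)` returns an iterable of chunk objects that `st.write_stream` unwraps itself. A minimal standalone sketch of both paths outside Streamlit (the Claude model id is one registered in lib/config.py below; the OpenAI model id and key handling are illustrative assumptions, not part of this commit):

```python
import os

from anthropic import Anthropic
from openai import OpenAI

messages = [{"role": "user", "content": "Say hello."}]

# Anthropic: streaming is a context manager; text_stream yields plain strings
anthropic_client = Anthropic(api_key=os.environ["ANTHROPIC_API_KEY"])
with anthropic_client.messages.stream(
    model="claude-3-haiku-20240307",  # one of the ids registered in lib/config.py
    max_tokens=512,
    system="You are a helpful assistant.",
    messages=messages,
) as stream:
    for text in stream.text_stream:
        print(text, end="", flush=True)

# OpenAI-compatible: create(stream=True) returns an iterable of chunk objects
openai_client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])  # illustrative
stream = openai_client.chat.completions.create(
    model="gpt-4o-mini",  # illustrative; the app targets HF and Perplexity models instead
    max_tokens=512,
    stream=True,
    messages=[{"role": "system", "content": "You are a helpful assistant."}, *messages],
)
for chunk in stream:
    print(chunk.choices[0].delta.content or "", end="", flush=True)
```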
lib/config.py
CHANGED
@@ -97,6 +97,15 @@ class AppConfig:
     services: Dict[str, ServiceConfig]
 
 
+_anthropic_text_kwargs = {
+    "system_prompt": TEXT_SYSTEM_PROMPT,
+    "max_tokens": 512,
+    "max_tokens_range": (512, 4096),
+    "temperature": 0.5,
+    "temperature_range": (0.0, 1.0),
+    "parameters": ["max_tokens", "temperature"],
+}
+
 _hf_text_kwargs = {
     "system_prompt": TEXT_SYSTEM_PROMPT,
     "frequency_penalty": 0.0,
@@ -129,7 +138,7 @@ _pplx_text_kwargs = {
     "max_tokens_range": (512, 4096),
     "temperature": 1.0,
     "temperature_range": (0.0, 2.0),
-    "parameters": ["max_tokens", "temperature", "frequency_penalty"
+    "parameters": ["max_tokens", "temperature", "frequency_penalty"],
 }
 
 config = AppConfig(
@@ -153,6 +162,17 @@ config = AppConfig(
         "sync_mode",
     ],
     services={
+        "anthropic": ServiceConfig(
+            name="Anthropic",
+            url="https://api.anthropic.com/v1",
+            api_key=os.environ.get("ANTHROPIC_API_KEY"),
+            text={
+                "claude-3-haiku-20240307": TextModelConfig("Claude 3 Haiku", **_anthropic_text_kwargs),
+                "claude-3-opus-20240229": TextModelConfig("Claude 3 Opus", **_anthropic_text_kwargs),
+                "claude-3-sonnet-20240229": TextModelConfig("Claude 3 Sonnet", **_anthropic_text_kwargs),
+                "claude-3-5-sonnet-20240620": TextModelConfig("Claude 3.5 Sonnet", **_anthropic_text_kwargs),
+            },
+        ),
         "bfl": ServiceConfig(
             name="Black Forest Labs",
             url="https://api.bfl.ml/v1",
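Two things worth noting in the new `_anthropic_text_kwargs` preset: `temperature_range` is (0.0, 1.0) because Anthropic caps temperature at 1, unlike the 0–2 scale of the OpenAI-style services, and the preset omits `frequency_penalty`, which the Messages API does not support. A rough sketch of how a page might consume this registry (the attribute names on `ServiceConfig`/`TextModelConfig` are inferred from this diff, since their definitions are not shown):

```python
from lib.config import config

# Pick the service and list its models for the sidebar selectbox
service_config = config.services["anthropic"]
for model_id, model_config in service_config.text.items():
    print(model_id, "->", model_config.name)  # assumes the first TextModelConfig arg is `name`

# Preset values drive the widgets: slider bounds, defaults, and which widgets render
model_config = service_config.text["claude-3-5-sonnet-20240620"]
print(model_config.temperature_range)  # (0.0, 1.0) for every Claude model above
print(model_config.parameters)         # ["max_tokens", "temperature"]
```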
pages/1_💬_Text_Generation.py
CHANGED
@@ -10,6 +10,9 @@ st.set_page_config(
     layout=config.layout,
 )
 
+if "api_key_anthropic" not in st.session_state:
+    st.session_state.api_key_anthropic = ""
+
 if "api_key_hf" not in st.session_state:
     st.session_state.api_key_hf = ""
 
@@ -80,7 +83,7 @@ st.html("""
 """)
 
 # Build parameters from preset by rendering the appropriate input widgets
-parameters = {}
+parameters = {"model": model}
 for param in model_config.parameters:
     if param == "max_tokens":
         parameters[param] = st.sidebar.slider(
@@ -92,6 +95,7 @@ for param in model_config.parameters:
             disabled=st.session_state.running,
             help="Maximum number of tokens to generate (default: 512)",
         )
+
     if param == "temperature":
         parameters[param] = st.sidebar.slider(
             "Temperature",
@@ -102,6 +106,7 @@ for param in model_config.parameters:
             disabled=st.session_state.running,
             help="Used to modulate the next token probabilities (default: 1.0)",
         )
+
     if param == "frequency_penalty":
         parameters[param] = st.sidebar.slider(
             "Frequency Penalty",
@@ -112,6 +117,7 @@ for param in model_config.parameters:
            disabled=st.session_state.running,
            help="Penalize new tokens based on their existing frequency in the text (default: 0.0)",
        )
+
    if param == "presence_penalty":
        parameters[param] = st.sidebar.slider(
            "Presence Penalty",
@@ -122,6 +128,7 @@ for param in model_config.parameters:
            disabled=st.session_state.running,
            help="Penalize new tokens based on their presence in the text so far (default: 0.0)",
        )
+
    if param == "seed":
        parameters[param] = st.sidebar.number_input(
            "Seed",
@@ -180,7 +187,12 @@ if prompt := st.chat_input(
    if button_container:
        button_container.empty()
 
-    messages = [{"role": "system", "content": system}]
+    if service == "anthropic":
+        messages = []
+        parameters["system"] = system
+    else:
+        messages = [{"role": "system", "content": system}]
+
     messages.extend([{"role": m["role"], "content": m["content"]} for m in st.session_state.txt2txt_messages])
     messages.append({"role": "user", "content": prompt})
     parameters["messages"] = messages
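The final hunk handles the one real API mismatch: Anthropic's Messages API takes the system prompt as a top-level `system` parameter and rejects a `"system"` role inside `messages`, while the OpenAI-style services expect it as the first message. A small illustration of the two payload shapes this code produces (all values hypothetical):

```python
system = "You are a helpful assistant."
history = [{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello!"}]
prompt = "What can you do?"

# Anthropic: no "system" role allowed in messages; it rides as a top-level key
anthropic_parameters = {
    "model": "claude-3-5-sonnet-20240620",
    "system": system,
    "messages": [*history, {"role": "user", "content": prompt}],
}

# OpenAI-compatible: the system prompt is just the first message
openai_parameters = {
    "model": "gpt-4o-mini",  # hypothetical; the app uses HF/Perplexity model ids
    "messages": [{"role": "system", "content": system}, *history, {"role": "user", "content": prompt}],
}
```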
requirements.txt
CHANGED
@@ -1,3 +1,4 @@
+anthropic==0.36.0
 h2
 httpx
 openai==1.41.0