adamelliotfields commited on
Commit
8407ae2
β€’
1 Parent(s): 6f2752e

Add Anthropic API

Browse files
0_🏠_Home.py CHANGED
@@ -49,6 +49,7 @@ st.page_link("pages/2_🎨_Text_to_Image.py", label="Text to Image", icon="🎨"
49
  st.markdown("""
50
  ## Services
51
 
 
52
  - [Black Forest Labs](https://docs.bfl.ml)
53
  - [fal.ai](https://fal.ai/docs)
54
  - [Hugging Face](https://huggingface.co/docs/api-inference/index)
 
49
  st.markdown("""
50
  ## Services
51
 
52
+ - [Anthropic](https://docs.anthropic.com/en/api/getting-started)
53
  - [Black Forest Labs](https://docs.bfl.ml)
54
  - [fal.ai](https://fal.ai/docs)
55
  - [Hugging Face](https://huggingface.co/docs/api-inference/index)
README.md CHANGED
@@ -27,6 +27,7 @@ models:
27
  Setting keys as environment variables persists them so you don't have to enter them on every page load.
28
 
29
  ```bash
 
30
  BFL_API_KEY=...
31
  FAL_KEY=...
32
  HF_TOKEN=...
 
27
  Setting keys as environment variables persists them so you don't have to enter them on every page load.
28
 
29
  ```bash
30
+ ANTHROPIC_API_KEY=...
31
  BFL_API_KEY=...
32
  FAL_KEY=...
33
  HF_TOKEN=...
lib/api.py CHANGED
@@ -4,23 +4,33 @@ import time
4
 
5
  import httpx
6
  import streamlit as st
7
- from openai import APIError, OpenAI
 
 
 
8
  from PIL import Image
9
 
10
  from .config import config
11
 
12
 
13
- def txt2txt_generate(api_key, service, model, parameters, **kwargs):
14
  base_url = config.services[service].url
15
 
16
  if service == "hf":
17
- base_url = f"{base_url}/{model}/v1"
18
- client = OpenAI(api_key=api_key, base_url=base_url)
19
 
20
  try:
21
- stream = client.chat.completions.create(stream=True, model=model, **parameters, **kwargs)
22
- return st.write_stream(stream)
23
- except APIError as e:
 
 
 
 
 
 
 
 
24
  # OpenAI uses this message for streaming errors and attaches response.error to error.body
25
  # https://github.com/openai/openai-python/blob/v1.0.0/src/openai/_streaming.py#L59
26
  return e.body if e.message == "An error occurred during streaming" else e.message
 
4
 
5
  import httpx
6
  import streamlit as st
7
+ from anthropic import Anthropic
8
+ from anthropic import APIError as AnthropicAPIError
9
+ from openai import APIError as OpenAIAPIError
10
+ from openai import OpenAI
11
  from PIL import Image
12
 
13
  from .config import config
14
 
15
 
16
+ def txt2txt_generate(api_key, service, parameters, **kwargs):
17
  base_url = config.services[service].url
18
 
19
  if service == "hf":
20
+ base_url = f"{base_url}/{parameters['model']}/v1"
 
21
 
22
  try:
23
+ if service == "anthropic":
24
+ client = Anthropic(api_key=api_key)
25
+ with client.messages.stream(**parameters, **kwargs) as stream:
26
+ return st.write_stream(stream.text_stream)
27
+ else:
28
+ client = OpenAI(api_key=api_key, base_url=base_url)
29
+ stream = client.chat.completions.create(stream=True, **parameters, **kwargs)
30
+ return st.write_stream(stream)
31
+ except AnthropicAPIError as e:
32
+ return e.message
33
+ except OpenAIAPIError as e:
34
  # OpenAI uses this message for streaming errors and attaches response.error to error.body
35
  # https://github.com/openai/openai-python/blob/v1.0.0/src/openai/_streaming.py#L59
36
  return e.body if e.message == "An error occurred during streaming" else e.message
lib/config.py CHANGED
@@ -97,6 +97,15 @@ class AppConfig:
97
  services: Dict[str, ServiceConfig]
98
 
99
 
 
 
 
 
 
 
 
 
 
100
  _hf_text_kwargs = {
101
  "system_prompt": TEXT_SYSTEM_PROMPT,
102
  "frequency_penalty": 0.0,
@@ -129,7 +138,7 @@ _pplx_text_kwargs = {
129
  "max_tokens_range": (512, 4096),
130
  "temperature": 1.0,
131
  "temperature_range": (0.0, 2.0),
132
- "parameters": ["max_tokens", "temperature", "frequency_penalty", "seed"],
133
  }
134
 
135
  config = AppConfig(
@@ -153,6 +162,17 @@ config = AppConfig(
153
  "sync_mode",
154
  ],
155
  services={
 
 
 
 
 
 
 
 
 
 
 
156
  "bfl": ServiceConfig(
157
  name="Black Forest Labs",
158
  url="https://api.bfl.ml/v1",
 
97
  services: Dict[str, ServiceConfig]
98
 
99
 
100
+ _anthropic_text_kwargs = {
101
+ "system_prompt": TEXT_SYSTEM_PROMPT,
102
+ "max_tokens": 512,
103
+ "max_tokens_range": (512, 4096),
104
+ "temperature": 0.5,
105
+ "temperature_range": (0.0, 1.0),
106
+ "parameters": ["max_tokens", "temperature"],
107
+ }
108
+
109
  _hf_text_kwargs = {
110
  "system_prompt": TEXT_SYSTEM_PROMPT,
111
  "frequency_penalty": 0.0,
 
138
  "max_tokens_range": (512, 4096),
139
  "temperature": 1.0,
140
  "temperature_range": (0.0, 2.0),
141
+ "parameters": ["max_tokens", "temperature", "frequency_penalty"],
142
  }
143
 
144
  config = AppConfig(
 
162
  "sync_mode",
163
  ],
164
  services={
165
+ "anthropic": ServiceConfig(
166
+ name="Anthropic",
167
+ url="https://api.anthropic.com/v1",
168
+ api_key=os.environ.get("ANTHROPIC_API_KEY"),
169
+ text={
170
+ "claude-3-haiku-20240307": TextModelConfig("Claude 3 Haiku", **_anthropic_text_kwargs),
171
+ "claude-3-opus-20240229": TextModelConfig("Claude 3 Opus", **_anthropic_text_kwargs),
172
+ "claude-3-sonnet-20240229": TextModelConfig("Claude 3 Sonnet", **_anthropic_text_kwargs),
173
+ "claude-3-5-sonnet-20240620": TextModelConfig("Claude 3.5 Sonnet", **_anthropic_text_kwargs),
174
+ },
175
+ ),
176
  "bfl": ServiceConfig(
177
  name="Black Forest Labs",
178
  url="https://api.bfl.ml/v1",
pages/1_πŸ’¬_Text_Generation.py CHANGED
@@ -10,6 +10,9 @@ st.set_page_config(
10
  layout=config.layout,
11
  )
12
 
 
 
 
13
  if "api_key_hf" not in st.session_state:
14
  st.session_state.api_key_hf = ""
15
 
@@ -80,7 +83,7 @@ st.html("""
80
  """)
81
 
82
  # Build parameters from preset by rendering the appropriate input widgets
83
- parameters = {}
84
  for param in model_config.parameters:
85
  if param == "max_tokens":
86
  parameters[param] = st.sidebar.slider(
@@ -92,6 +95,7 @@ for param in model_config.parameters:
92
  disabled=st.session_state.running,
93
  help="Maximum number of tokens to generate (default: 512)",
94
  )
 
95
  if param == "temperature":
96
  parameters[param] = st.sidebar.slider(
97
  "Temperature",
@@ -102,6 +106,7 @@ for param in model_config.parameters:
102
  disabled=st.session_state.running,
103
  help="Used to modulate the next token probabilities (default: 1.0)",
104
  )
 
105
  if param == "frequency_penalty":
106
  parameters[param] = st.sidebar.slider(
107
  "Frequency Penalty",
@@ -112,6 +117,7 @@ for param in model_config.parameters:
112
  disabled=st.session_state.running,
113
  help="Penalize new tokens based on their existing frequency in the text (default: 0.0)",
114
  )
 
115
  if param == "presence_penalty":
116
  parameters[param] = st.sidebar.slider(
117
  "Presence Penalty",
@@ -122,6 +128,7 @@ for param in model_config.parameters:
122
  disabled=st.session_state.running,
123
  help="Penalize new tokens based on their presence in the text so far (default: 0.0)",
124
  )
 
125
  if param == "seed":
126
  parameters[param] = st.sidebar.number_input(
127
  "Seed",
@@ -180,7 +187,12 @@ if prompt := st.chat_input(
180
  if button_container:
181
  button_container.empty()
182
 
183
- messages = [{"role": "system", "content": system}]
 
 
 
 
 
184
  messages.extend([{"role": m["role"], "content": m["content"]} for m in st.session_state.txt2txt_messages])
185
  messages.append({"role": "user", "content": prompt})
186
  parameters["messages"] = messages
 
10
  layout=config.layout,
11
  )
12
 
13
+ if "api_key_anthropic" not in st.session_state:
14
+ st.session_state.api_key_anthropic = ""
15
+
16
  if "api_key_hf" not in st.session_state:
17
  st.session_state.api_key_hf = ""
18
 
 
83
  """)
84
 
85
  # Build parameters from preset by rendering the appropriate input widgets
86
+ parameters = {"model": model}
87
  for param in model_config.parameters:
88
  if param == "max_tokens":
89
  parameters[param] = st.sidebar.slider(
 
95
  disabled=st.session_state.running,
96
  help="Maximum number of tokens to generate (default: 512)",
97
  )
98
+
99
  if param == "temperature":
100
  parameters[param] = st.sidebar.slider(
101
  "Temperature",
 
106
  disabled=st.session_state.running,
107
  help="Used to modulate the next token probabilities (default: 1.0)",
108
  )
109
+
110
  if param == "frequency_penalty":
111
  parameters[param] = st.sidebar.slider(
112
  "Frequency Penalty",
 
117
  disabled=st.session_state.running,
118
  help="Penalize new tokens based on their existing frequency in the text (default: 0.0)",
119
  )
120
+
121
  if param == "presence_penalty":
122
  parameters[param] = st.sidebar.slider(
123
  "Presence Penalty",
 
128
  disabled=st.session_state.running,
129
  help="Penalize new tokens based on their presence in the text so far (default: 0.0)",
130
  )
131
+
132
  if param == "seed":
133
  parameters[param] = st.sidebar.number_input(
134
  "Seed",
 
187
  if button_container:
188
  button_container.empty()
189
 
190
+ if service == "anthropic":
191
+ messages = []
192
+ parameters["system"] = system
193
+ else:
194
+ messages = [{"role": "system", "content": system}]
195
+
196
  messages.extend([{"role": m["role"], "content": m["content"]} for m in st.session_state.txt2txt_messages])
197
  messages.append({"role": "user", "content": prompt})
198
  parameters["messages"] = messages
requirements.txt CHANGED
@@ -1,3 +1,4 @@
 
1
  h2
2
  httpx
3
  openai==1.41.0
 
1
+ anthropic==0.36.0
2
  h2
3
  httpx
4
  openai==1.41.0