Spaces: Running
Update app.py
app.py CHANGED
@@ -1,21 +1,181 @@
 import gradio as gr
-
 
-""
-
-""
-
 
 
-def
     message,
     history: list[tuple[str, str]],
-
     max_tokens,
     temperature,
     top_p,
 ):
-    messages = [
 
     for val in history:
         if val[0]:
@@ -25,39 +185,154 @@ def respond(
 
     messages.append({"role": "user", "content": message})
 
-
 
-
-
-
-        stream=True,
-        temperature=temperature,
-        top_p=top_p,
-    ):
-        token = message.choices[0].delta.content
 
-
-
 
-
-    For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
-    """
 demo = gr.ChatInterface(
     respond,
     additional_inputs=[
-        gr.
-        gr.Slider(minimum=1, maximum=
-        gr.Slider(minimum=0.1, maximum=
-        gr.Slider(
-            minimum=0.1,
-            maximum=1.0,
-            value=0.95,
-            step=0.05,
-            label="Top-p (nucleus sampling)",
-        ),
     ],
 )
 
 
 if __name__ == "__main__":
-
+from fastapi import FastAPI, Request
+from fastapi.responses import HTMLResponse  # needed by the /declined route below
+from typing import List
 import gradio as gr
+import threading  # needed by the "imagine" tool handling in handleResponse
+import requests
+import argparse
+import aiohttp
+import uvicorn
+import random
+import string
+import json
+import math
+import sys
+import os
 
+API_BASE = "env"
+api_key = os.environ['OPENAI_API_KEY']
+base_url = os.environ.get('OPENAI_BASE_URL', "https://api.openai.com/v1")
+def_models = '["gpt-4", "gpt-4-0125-preview", "gpt-4-0613", "gpt-4-1106-preview", "gpt-4-turbo", "gpt-4-turbo-2024-04-09", "gpt-4-turbo-preview", "chatgpt-4o-latest", "gpt-4o", "gpt-4o-2024-05-13", "gpt-4o-2024-08-06", "gpt-4o-mini", "gpt-4o-mini-2024-07-18"]'
 
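A note on the configuration block above: everything is driven by two environment variables, and `OPENAI_API_KEY` may carry a comma-separated pool of keys. A minimal, hypothetical local setup (values illustrative; the pool is validated by `handleApiKeys()` further down):

```python
import os

# Hypothetical local setup; variable names come from the code above.
# A comma-separated OPENAI_API_KEY acts as a key pool: handleApiKeys()
# drops the keys that fail, and get_api_key() picks one at random per request.
os.environ["OPENAI_API_KEY"] = "sk-key-one,sk-key-two"       # illustrative keys
os.environ["OPENAI_BASE_URL"] = "https://api.openai.com/v1"  # optional override
```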
+def checkModels():
+    global base_url, api_key
+    if API_BASE == "env":
+        try:
+            response = requests.get(f"{base_url}/models", headers={"Authorization": f"Bearer {get_api_key()}"})
+            response.raise_for_status()
+            if not ('data' in response.json()):
+                base_url = "https://api.openai.com/v1"
+                api_key = oai_api_key  # NOTE: oai_api_key is referenced here and below but never defined in this file
+        except Exception as e:
+            print(f"Error testing API endpoint: {e}")
+    else:
+        base_url = "https://api.openai.com/v1"
+        api_key = oai_api_key
 
+def loadModels():
+    global models, modelList
+    models = json.loads(def_models)
+    models = sorted(models)
+
+    modelList = {
+        "object": "list",
+        "data": [{"id": v, "object": "model", "created": 0, "owned_by": "system"} for v in models]
+    }
+
+def handleApiKeys():
+    # Validate each comma-separated key against /models and keep only the working ones.
+    global api_key
+    if ',' in api_key:
+        output = []
+        for key in api_key.split(','):
+            try:
+                response = requests.get(f"{base_url}/models", headers={"Authorization": f"Bearer {key}"})
+                response.raise_for_status()
+                if ('data' in response.json()):
+                    output.append(key)
+            except Exception as e:
+                print(f"API key {key} is not valid or an actual error happened: {e}")
+        if len(output) == 0:
+            raise RuntimeError("No API key is working")
+        api_key = ",".join(output)
+    else:
+        try:
+            response = requests.get(f"{base_url}/models", headers={"Authorization": f"Bearer {api_key}"})
+            response.raise_for_status()
+            if not ('data' in response.json()):
+                raise RuntimeError("Current API key is not valid")
+        except Exception as e:
+            raise RuntimeError(f"Current API key is not valid or an actual error happened: {e}")
+
+def encodeChat(messages):
+    # Flatten a chat message list into one ChatML-style string for the moderation endpoint.
+    output = []
+    for message in messages:
+        role = message['role']
+        name = f" [{message['name']}]" if 'name' in message else ''
+        content = message['content']
+        formatted_message = f"<|im_start|>{role}{name}\n{content}<|end_of_text|>"
+        output.append(formatted_message)
+    return "\n".join(output)
+
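`encodeChat` flattens the structured message list into a single ChatML-style string before it is sent to the moderation endpoint. A small sketch of the expected round trip, assuming the function as defined above:

```python
# Sketch: what encodeChat() produces for a two-message chat.
messages = [
    {"role": "user", "content": "Hello"},
    {"role": "assistant", "name": "bot", "content": "Hi!"},
]
print(encodeChat(messages))
# <|im_start|>user
# Hello<|end_of_text|>
# <|im_start|>assistant [bot]
# Hi!<|end_of_text|>
```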
+def get_api_key(call='api_key'):
+    if call == 'api_key':
+        key = api_key
+    elif call == 'oai_api_key':
+        key = oai_api_key
+    else:
+        key = api_key
+
+    # Comma-separated keys act as a simple pool; pick one at random per request.
+    if ',' in key:
+        return random.choice(key.split(','))
+    return key
+
+def moderate(messages):
+    # Run the conversation through the moderation endpoint; fall back to
+    # api.openai.com if the configured base_url fails.
+    try:
+        response = requests.post(
+            f"{base_url}/moderations",
+            headers={
+                "Content-Type": "application/json",
+                "Authorization": f"Bearer {get_api_key(call='api_key')}"
+            },
+            json={"input": encodeChat(messages)}
+        )
+        response.raise_for_status()
+        moderation_result = response.json()
+    except requests.exceptions.RequestException as e:
+        print(f"Error during moderation request to {base_url}: {e}")
+        try:
+            response = requests.post(
+                "https://api.openai.com/v1/moderations",
+                headers={
+                    "Content-Type": "application/json",
+                    "Authorization": f"Bearer {get_api_key(call='oai_api_key')}"
+                },
+                json={"input": encodeChat(messages)}
+            )
+            response.raise_for_status()
+            moderation_result = response.json()
+        except requests.exceptions.RequestException as e:
+            print(f"Error during moderation request to fallback URL: {e}")
+            return False
+
+    try:
+        if any(result["flagged"] for result in moderation_result["results"]):
+            return moderation_result
+    except KeyError:
+        if moderation_result["flagged"]:
+            return moderation_result
+
+    return False
+
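`moderate` returns `False` when nothing is flagged (or when both moderation requests fail), and the raw moderation payload otherwise. A minimal sketch of how a caller might inspect that payload, assuming the standard `results`/`categories` shape:

```python
# Sketch: inspecting moderate()'s return value. Falsy means "not flagged"
# (or that moderation itself was unreachable and returned False).
mode = moderate([{"role": "user", "content": "example input"}])
if mode:
    for result in mode.get("results", []):
        flagged = [cat for cat, hit in result.get("categories", {}).items() if hit]
        print("flagged categories:", flagged)
```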
+async def streamChat(params):
+    # Stream chat completions as server-sent events and yield each decoded
+    # JSON chunk; on connection errors, retry once against api.openai.com.
+    async with aiohttp.ClientSession() as session:
+        try:
+            async with session.post(f"{base_url}/chat/completions", headers={"Authorization": f"Bearer {get_api_key(call='api_key')}", "Content-Type": "application/json"}, json=params) as r:
+                r.raise_for_status()
+                async for line in r.content:
+                    if line:
+                        line_str = line.decode('utf-8')
+                        if line_str.startswith("data: "):
+                            line_str = line_str[6:].strip()
+                        if line_str == "[DONE]":
+                            continue
+                        try:
+                            message = json.loads(line_str)
+                            yield message
+                        except json.JSONDecodeError:
+                            continue
+        except aiohttp.ClientError:
+            try:
+                async with session.post("https://api.openai.com/v1/chat/completions", headers={"Authorization": f"Bearer {get_api_key(call='oai_api_key')}", "Content-Type": "application/json"}, json=params) as r:
+                    r.raise_for_status()
+                    async for line in r.content:
+                        if line:
+                            line_str = line.decode('utf-8')
+                            if line_str.startswith("data: "):
+                                line_str = line_str[6:].strip()
+                            if line_str == "[DONE]":
+                                continue
+                            try:
+                                message = json.loads(line_str)
+                                yield message
+                            except json.JSONDecodeError:
+                                continue
+            except aiohttp.ClientError:
+                return
+
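`streamChat` is an async generator that yields one parsed JSON chunk per server-sent event. A minimal consumer sketch (the payload mirrors what `respond()` builds later in this file):

```python
import asyncio

# Minimal consumer for the async generator above; the payload mirrors the
# one respond() builds further down in this file.
async def main():
    params = {
        "model": "gpt-4o-mini",
        "messages": [{"role": "user", "content": "Say hi"}],
        "stream": True,
    }
    async for chunk in streamChat(params):
        print(chunk["choices"][0]["delta"].get("content", ""), end="")

asyncio.run(main())
```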
+def rnd(length=8):
+    # Random alphanumeric ID used as the OpenAI "user" field per request.
+    letters = string.ascii_letters + string.digits
+    return ''.join(random.choice(letters) for i in range(length))
+
+
+async def respond(
     message,
     history: list[tuple[str, str]],
+    model_name,
     max_tokens,
     temperature,
     top_p,
 ):
+    messages = []
 
     for val in history:
         if val[0]:
             messages.append({"role": "user", "content": val[0]})
         if val[1]:
             messages.append({"role": "assistant", "content": val[1]})
 
     messages.append({"role": "user", "content": message})
 
+    # Refuse immediately if moderation flags the incoming conversation.
+    if message:
+        mode = moderate(messages)
+        if mode:
+            reasons = []
+            categories = mode[0].get('categories', {}) if isinstance(mode, list) else mode.get('categories', {})
+            for category, flagged in categories.items():
+                if flagged:
+                    reasons.append(category)
+            if reasons:
+                yield "[MODERATION] I'm sorry, but I can't assist with that.\n\nReasons:\n```\n" + "\n".join([f"{i+1}. {reason}" for i, reason in enumerate(reasons)]) + "\n```"
+            else:
+                yield "[MODERATION] I'm sorry, but I can't assist with that."
+            return
+
+    async def handleResponse(completion, prefix="", image_count=0, didSearchedAlready=False):
+        # Stream the completion out, re-moderate the combined exchange, then scan
+        # each response line for JSON tool calls (imagine / calc / search).
+        # NOTE: handleMultimodalData, imagine, safe_eval and searchEngine are not
+        # defined in this file.
+        response = ""
+        isRequeryNeeded = False
+        async for token in completion:
+            response += token['choices'][0]['delta'].get("content", token['choices'][0]['delta'].get("refusal", ""))
+            yield f"{prefix}{response}"
+        mode = moderate([handleMultimodalData(model_name, "user", message), {"role": "assistant", "content": response}])
+        if mode:
+            reasons = []
+            categories = mode[0].get('categories', {}) if isinstance(mode, list) else mode.get('categories', {})
+            for category, flagged in categories.items():
+                if flagged:
+                    reasons.append(category)
+            if reasons:
+                yield "[MODERATION] I'm sorry, but I can't assist with that.\n\nReasons:\n```\n" + "\n".join([f"{i+1}. {reason}" for i, reason in enumerate(reasons)]) + "\n```"
+            else:
+                yield "[MODERATION] I'm sorry, but I can't assist with that."
+            return
+        for line in response.split('\n'):
+            try:
+                data = json.loads(line)
+                if isinstance(data, dict) and data.get("tool") == "imagine" and data.get("isCall") and "prompt" in data:
+                    if image_count < 4:
+                        image_count += 1
+                        def fetch_image_url(prompt, line):
+                            image_url = imagine(prompt)
+                            return line, f'<img src="{image_url}" alt="{prompt}" width="512"/>'
+
+                        def replace_line_in_response(line, replacement):
+                            nonlocal response
+                            response = response.replace(line, replacement)
+
+                        thread = threading.Thread(target=lambda: replace_line_in_response(*fetch_image_url(data["prompt"], line)))
+                        thread.start()
+                        thread.join()
+                    else:
+                        response = response.replace(line, f'[System: 4 images per message limit; prompt asked: `{data["prompt"]}`]')
+                    yield f"{prefix}{response}"
+                elif isinstance(data, dict) and data.get("tool") == "calc" and data.get("isCall") and "prompt" in data:
+                    isRequeryNeeded = True
+                    try:
+                        result = safe_eval(data["prompt"])
+                        response = response.replace(line, f'[System: `{data["prompt"]}` === `{result}`]')
+                    except Exception as e:
+                        response = response.replace(line, f'[System: Error in calculation; `{e}`]')
+                    yield f"{prefix}{response}"
+                elif isinstance(data, dict) and data.get("tool") == "search" and data.get("isCall") and "prompt" in data:
+                    isRequeryNeeded = True
+                    if didSearchedAlready:
+                        response = response.replace(line, f'[System: One search per response is allowed, because of how long it takes and how many resources it uses; query: `{data["prompt"]}`]')
+                    else:
+                        try:
+                            result = searchEngine(data["prompt"])
+                            result_escaped = result.replace('`', '\\`')
+                            response = response.replace(line, f'[System: `{data["prompt"]}` ===\n```\n{result_escaped}\n```\n]')
+                            didSearchedAlready = True
+                        except Exception as e:
+                            response = response.replace(line, f'[System: Error in search function; `{e}`]')
+                    yield f"{prefix}{response}"
+                yield f"{prefix}{response}"
+            except (json.JSONDecodeError, AttributeError, Exception):
+                continue
+        if isRequeryNeeded:
+            # Feed the tool output back to the model for a follow-up completion.
+            messages.append({"role": "assistant", "content": response})
+            async for res in handleResponse(streamChat({
+                "model": model_name,
+                "messages": messages,
+                "max_tokens": max_tokens,
+                "temperature": temperature,
+                "top_p": top_p,
+                "user": rnd(),
+                "stream": True
+            }), f"{prefix}{response}\n\n", image_count, didSearchedAlready):
+                yield res
+
+    async for res in handleResponse(streamChat({
+        "model": model_name,
+        "messages": messages,
+        "max_tokens": max_tokens,
+        "temperature": temperature,
+        "top_p": top_p,
+        "user": rnd(),
+        "stream": True
+    })):
+        yield res
+
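The tool-call convention `handleResponse` relies on is simply one JSON object per output line, with `tool`, `isCall`, and `prompt` keys. A standalone sketch of the detection step (the example lines are hypothetical model output):

```python
import json

# The per-line tool-call convention handleResponse() scans for.
lines = [
    "Sure, let me compute that.",
    '{"tool": "calc", "isCall": true, "prompt": "2 + 2"}',
]
for line in lines:
    try:
        data = json.loads(line)
    except json.JSONDecodeError:
        continue  # plain prose lines are left as-is
    if isinstance(data, dict) and data.get("isCall") and "prompt" in data:
        print("tool:", data["tool"], "prompt:", data["prompt"])
```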
 
+handleApiKeys(); loadModels(); checkModels()
 demo = gr.ChatInterface(
     respond,
+    title="gpt-4o-mini-small",
+    description="This is the smaller version of the quardo/gpt-4o-small space.<br/>It mainly exists for when the main space is down.",
     additional_inputs=[
+        gr.Dropdown(choices=models, value="gpt-4o-mini", label="Model"),
+        gr.Slider(minimum=1, maximum=4096, value=4096, step=1, label="Max new tokens"),
+        gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.05, label="Temperature"),
+        gr.Slider(minimum=0.05, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
     ],
+    css="footer{display:none !important}",
+    head="""<script>if(!confirm("By using our application, which integrates with OpenAI's API, you acknowledge and agree to the following terms regarding the data you provide:\\n\\n1. Data Collection: This application may log the following data through the Gradio endpoint or the API endpoint: message requests (including messages, responses, model settings, and images sent along with the messages), images that were generated (including only the prompt and the image), search tool calls (including query, search results, summaries, and output responses), and moderation checks (including input and output).\\n2. Data Retention and Removal: Data is retained until further notice or until a specific request for removal is made.\\n3. Data Usage: The collected data may be used for various purposes, including but not limited to, administrative review of logs, AI training, and publication as a dataset.\\n4. Privacy: Please avoid sharing any personal information.\\n\\nBy continuing to use our application, you explicitly consent to the collection, use, and potential sharing of your data as described above. If you disagree with our data collection, usage, and sharing practices, we advise you not to use our application."))location.href="/declined";</script>"""
 )
 
+app = FastAPI()
+
+@app.get("/declined")
+def test():
+    # Shown when the user rejects the data-collection notice injected via `head`.
+    return HTMLResponse(content="""
+    <html>
+        <head>
+            <title>Declined</title>
+        </head>
+        <body>
+            <p>Ok, you can go back to Hugging Face. I just didn't have any idea how to handle a decline, so you are redirected here.</p><br/>
+            <a href="/">Go back</a>
+        </body>
+    </html>
+    """)
+
+app = gr.mount_gradio_app(app, demo, path="/")
+
+class ArgParser(argparse.ArgumentParser):
+    def __init__(self, *args, **kwargs):
+        super(ArgParser, self).__init__(*args, **kwargs)
+
+        self.add_argument("-s", "--server", type=str, default="0.0.0.0")
+        self.add_argument("-p", "--port", type=int, default=7860)
+        self.add_argument("-d", "--dev", default=False, action="store_true")
+
+        self.args = self.parse_args(sys.argv[1:])
 
 if __name__ == "__main__":
+    args = ArgParser().args
+    if args.dev:
+        uvicorn.run("__main__:app", host=args.server, port=args.port, reload=True)
+    else:
+        uvicorn.run("__main__:app", host=args.server, port=args.port, reload=False)
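With the `ArgParser` defaults above, the Space should start locally with plain `python app.py` (equivalent to `python app.py -s 0.0.0.0 -p 7860`); adding `-d` switches uvicorn into auto-reload mode for development.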