Quardo committed
Commit: 00ed5fb
Parent: 2741f4b

Update app.py

Files changed (1): app.py (+308, -33)
app.py CHANGED
@@ -1,21 +1,181 @@
+from fastapi import FastAPI, Request
+from fastapi.responses import HTMLResponse  # HTMLResponse is used by the /declined route below
+from typing import List
 import gradio as gr
-from huggingface_hub import InferenceClient
+import requests
+import argparse
+import threading  # threading is used by the image tool-call handler below
+import aiohttp
+import uvicorn
+import random
+import string
+import json
+import math
+import sys
+import os
 
-"""
-For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
-"""
-client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
+API_BASE = "env"
+api_key = os.environ['OPENAI_API_KEY']
+base_url = os.environ.get('OPENAI_BASE_URL', "https://api.openai.com/v1")
+def_models = '["gpt-4", "gpt-4-0125-preview", "gpt-4-0613", "gpt-4-1106-preview", "gpt-4-turbo", "gpt-4-turbo-2024-04-09", "gpt-4-turbo-preview", "chatgpt-4o-latest", "gpt-4o", "gpt-4o-2024-05-13", "gpt-4o-2024-08-06", "gpt-4o-mini", "gpt-4o-mini-2024-07-18"]'
 
+def checkModels():
+    global base_url
+    if API_BASE == "env":
+        try:
+            response = requests.get(f"{base_url}/models", headers={"Authorization": f"Bearer {get_api_key()}"})
+            response.raise_for_status()
+            if not ('data' in response.json()):
+                base_url = "https://api.openai.com/v1"
+                api_key = oai_api_key
+        except Exception as e:
+            print(f"Error testing API endpoint: {e}")
+    else:
+        base_url = "https://api.openai.com/v1"
+        api_key = oai_api_key
+
+def loadModels():
+    global models, modelList
+    models = json.loads(def_models)
+    models = sorted(models)
+
+    modelList = {
+        "object": "list",
+        "data": [{"id": v, "object": "model", "created": 0, "owned_by": "system"} for v in models]
+    }
+
+def handleApiKeys():
+    global api_key
+    if ',' in api_key:
+        output = []
+        for key in api_key.split(','):
+            try:
+                response = requests.get(f"{base_url}/models", headers={"Authorization": f"Bearer {key}"})
+                response.raise_for_status()
+                if ('data' in response.json()):
+                    output.append(key)
+            except Exception as e:
+                print(f"API key {key} is not valid or an actual error happened: {e}")
+        if len(output) == 0:
+            raise RuntimeError("No API key is working")
+        api_key = ",".join(output)
+    else:
+        try:
+            response = requests.get(f"{base_url}/models", headers={"Authorization": f"Bearer {api_key}"})
+            response.raise_for_status()
+            if not ('data' in response.json()):
+                raise RuntimeError("Current API key is not valid")
+        except Exception as e:
+            raise RuntimeError(f"Current API key is not valid or an actual error happened: {e}")
+
+def encodeChat(messages):
+    output = []
+    for message in messages:
+        role = message['role']
+        name = f" [{message['name']}]" if 'name' in message else ''
+        content = message['content']
+        formatted_message = f"<|im_start|>{role}{name}\n{content}<|end_of_text|>"
+        output.append(formatted_message)
+    return "\n".join(output)
+
+def get_api_key(call='api_key'):
+    if call == 'api_key':
+        key = api_key
+    elif call == 'oai_api_key':
+        key = oai_api_key
+    else:
+        key = api_key
+
+    if ',' in key:
+        return random.choice(key.split(','))
+    return key
+
+def moderate(messages):
+    try:
+        response = requests.post(
+            f"{base_url}/moderations",
+            headers={
+                "Content-Type": "application/json",
+                "Authorization": f"Bearer {get_api_key(call='api_key')}"
+            },
+            json={"input": encodeChat(messages)}
+        )
+        response.raise_for_status()
+        moderation_result = response.json()
+    except requests.exceptions.RequestException as e:
+        print(f"Error during moderation request to {base_url}: {e}")
+        try:
+            response = requests.post(
+                "https://api.openai.com/v1/moderations",
+                headers={
+                    "Content-Type": "application/json",
+                    "Authorization": f"Bearer {get_api_key(call='oai_api_key')}"
+                },
+                json={"input": encodeChat(messages)}
+            )
+            response.raise_for_status()
+            moderation_result = response.json()
+        except requests.exceptions.RequestException as e:
+            print(f"Error during moderation request to fallback URL: {e}")
+            return False
+
+    try:
+        if any(result["flagged"] for result in moderation_result["results"]):
+            return moderation_result
+    except KeyError:
+        if moderation_result["flagged"]:
+            return moderation_result
+
+    return False
+
+async def streamChat(params):
+    async with aiohttp.ClientSession() as session:
+        try:
+            async with session.post(f"{base_url}/chat/completions", headers={"Authorization": f"Bearer {get_api_key(call='api_key')}", "Content-Type": "application/json"}, json=params) as r:
+                r.raise_for_status()
+                async for line in r.content:
+                    if line:
+                        line_str = line.decode('utf-8')
+                        if line_str.startswith("data: "):
+                            line_str = line_str[6:].strip()
+                        if line_str == "[DONE]":
+                            continue
+                        try:
+                            message = json.loads(line_str)
+                            yield message
+                        except json.JSONDecodeError:
+                            continue
+        except aiohttp.ClientError:
+            try:
+                async with session.post("https://api.openai.com/v1/chat/completions", headers={"Authorization": f"Bearer {get_api_key(call='oai_api_key')}", "Content-Type": "application/json"}, json=params) as r:
+                    r.raise_for_status()
+                    async for line in r.content:
+                        if line:
+                            line_str = line.decode('utf-8')
+                            if line_str.startswith("data: "):
+                                line_str = line_str[6:].strip()
+                            if line_str == "[DONE]":
+                                continue
+                            try:
+                                message = json.loads(line_str)
+                                yield message
+                            except json.JSONDecodeError:
+                                continue
+            except aiohttp.ClientError:
+                return
+
+def rnd(length=8):
+    letters = string.ascii_letters + string.digits
+    return ''.join(random.choice(letters) for i in range(length))
 
 
-def respond(
+async def respond(
     message,
     history: list[tuple[str, str]],
-    system_message,
+    model_name,
     max_tokens,
     temperature,
     top_p,
 ):
-    messages = [{"role": "system", "content": system_message}]
+    messages = [];
 
     for val in history:
         if val[0]:
@@ -25,39 +185,154 @@ def respond(
 
     messages.append({"role": "user", "content": message})
 
-    response = ""
-
-    for message in client.chat_completion(
-        messages,
-        max_tokens=max_tokens,
-        stream=True,
-        temperature=temperature,
-        top_p=top_p,
-    ):
-        token = message.choices[0].delta.content
-
-        response += token
-        yield response
-
-"""
-For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
-"""
+    if message:
+        mode = moderate(messages)
+        if mode:
+            reasons = []
+            categories = mode[0].get('categories', {}) if isinstance(mode, list) else mode.get('categories', {})
+            for category, flagged in categories.items():
+                if flagged:
+                    reasons.append(category)
+            if reasons:
+                yield "[MODERATION] I'm sorry, but I can't assist with that.\n\nReasons:\n```\n" + "\n".join([f"{i+1}. {reason}" for i, reason in enumerate(reasons)]) + "\n```"
+            else:
+                yield "[MODERATION] I'm sorry, but I can't assist with that."
+            return
+
+    async def handleResponse(completion, prefix="", image_count=0, didSearchedAlready=False):
+        response = ""
+        isRequeryNeeded = False
+        async for token in completion:
+            response += token['choices'][0]['delta'].get("content", token['choices'][0]['delta'].get("refusal", ""))
+            yield f"{prefix}{response}"
+        mode = moderate([handleMultimodalData(model_name, "user", message), {"role": "assistant", "content": response}])
+        if mode:
+            reasons = []
+            categories = mode[0].get('categories', {}) if isinstance(mode, list) else mode.get('categories', {})
+            for category, flagged in categories.items():
+                if flagged:
+                    reasons.append(category)
+            if reasons:
+                yield "[MODERATION] I'm sorry, but I can't assist with that.\n\nReasons:\n```\n" + "\n".join([f"{i+1}. {reason}" for i, reason in enumerate(reasons)]) + "\n```"
+            else:
+                yield "[MODERATION] I'm sorry, but I can't assist with that."
+            return
+        for line in response.split('\n'):
+            try:
+                data = json.loads(line)
+                if isinstance(data, dict) and data.get("tool") == "imagine" and data.get("isCall") and "prompt" in data:
+                    if image_count < 4:
+                        image_count += 1
+                        def fetch_image_url(prompt, line):
+                            image_url = imagine(prompt)
+                            return line, f'<img src="{image_url}" alt="{prompt}" width="512"/>'
+
+                        def replace_line_in_response(line, replacement):
+                            nonlocal response
+                            response = response.replace(line, replacement)
+
+                        thread = threading.Thread(target=lambda: replace_line_in_response(*fetch_image_url(data["prompt"], line)))
+                        thread.start()
+                        thread.join()
+                    else:
+                        response = response.replace(line, f'[System: 4 images per message limit; prompt asked: `{data["prompt"]}`]')
+                    yield f"{prefix}{response}"
+                elif isinstance(data, dict) and data.get("tool") == "calc" and data.get("isCall") and "prompt" in data:
+                    isRequeryNeeded = True
+                    try:
+                        result = safe_eval(data["prompt"])
+                        response = response.replace(line, f'[System: `{data["prompt"]}` === `{result}`]')
+                    except Exception as e:
+                        response = response.replace(line, f'[System: Error in calculation; `{e}`]')
+                    yield f"{prefix}{response}"
+                elif isinstance(data, dict) and data.get("tool") == "search" and data.get("isCall") and "prompt" in data:
+                    isRequeryNeeded = True
+                    if didSearchedAlready:
+                        response = response.replace(line, f'[System: Only one search per response is allowed, because searching is slow and resource-intensive; query: `{data["prompt"]}`]')
+                    else:
+                        try:
+                            result = searchEngine(data["prompt"])
+                            result_escaped = result.replace('`', '\\`')
+                            response = response.replace(line, f'[System: `{data["prompt"]}` ===\n```\n{result_escaped}\n```\n]')
+                            didSearchedAlready = True
+                        except Exception as e:
+                            response = response.replace(line, f'[System: Error in search function; `{e}`]')
+                    yield f"{prefix}{response}"
+                yield f"{prefix}{response}"
+            except (json.JSONDecodeError, AttributeError, Exception):
+                continue
+        if isRequeryNeeded:
+            messages.append({"role": "assistant", "content": response})
+            async for res in handleResponse(streamChat({
+                "model": model_name,
+                "messages": messages,
+                "max_tokens": max_tokens,
+                "temperature": temperature,
+                "top_p": top_p,
+                "user": rnd(),
+                "stream": True
+            }), f"{prefix}{response}\n\n", image_count, didSearchedAlready):
+                yield res
+    async for res in handleResponse(streamChat({
+        "model": model_name,
+        "messages": messages,
+        "max_tokens": max_tokens,
+        "temperature": temperature,
+        "top_p": top_p,
+        "user": rnd(),
+        "stream": True
+    })):
+        yield res
+
 
+handleApiKeys();loadModels();checkModels();
 demo = gr.ChatInterface(
     respond,
+    title="gpt-4o-mini-small",
+    description="This is the smaller version of the quardo/gpt-4o-small space.<br/>It mainly exists for when the main space is down.",
     additional_inputs=[
-        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-        gr.Slider(
-            minimum=0.1,
-            maximum=1.0,
-            value=0.95,
-            step=0.05,
-            label="Top-p (nucleus sampling)",
-        ),
+        gr.Dropdown(choices=models, value="gpt-4o-mini", label="Model"),
+        gr.Slider(minimum=1, maximum=4096, value=4096, step=1, label="Max new tokens"),
+        gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.05, label="Temperature"),
+        gr.Slider(minimum=0.05, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
     ],
+    css="footer{display:none !important}",
+    head="""<script>if(!confirm("By using our application, which integrates with OpenAI's API, you acknowledge and agree to the following terms regarding the data you provide:\\n\\n1. Data Collection: This application may log the following data through the Gradio endpoint or the API endpoint: message requests (including messages, responses, model settings, and images sent along with the messages), images that were generated (including only the prompt and the image), search tool calls (including query, search results, summaries, and output responses), and moderation checks (including input and output).\\n2. Data Retention and Removal: Data is retained until further notice or until a specific request for removal is made.\\n3. Data Usage: The collected data may be used for various purposes, including but not limited to, administrative review of logs, AI training, and publication as a dataset.\\n4. Privacy: Please avoid sharing any personal information.\\n\\nBy continuing to use our application, you explicitly consent to the collection, use, and potential sharing of your data as described above. If you disagree with our data collection, usage, and sharing practices, we advise you not to use our application."))location.href="/declined";</script>"""
 )
 
+app = FastAPI()
+
+@app.get("/declined")
+def test():
+    return HTMLResponse(content="""
+    <html>
+      <head>
+        <title>Declined</title>
+      </head>
+      <body>
+        <p>Ok, you can go back to Hugging Face. I just didn't have any idea how to handle a decline, so you are redirected here.</p><br/>
+        <a href="/">Go back</a>
+      </body>
+    </html>
+    """)
+
+app = gr.mount_gradio_app(app, demo, path="/")
+
+class ArgParser(argparse.ArgumentParser):
+    def __init__(self, *args, **kwargs):
+        super(ArgParser, self).__init__(*args, **kwargs)
+
+        self.add_argument("-s", "--server", type=str, default="0.0.0.0")
+        self.add_argument("-p", "--port", type=int, default=7860)
+        self.add_argument("-d", "--dev", default=False, action="store_true")
+
+        self.args = self.parse_args(sys.argv[1:])
 
 if __name__ == "__main__":
-    demo.launch()
+    args = ArgParser().args
+    if args.dev:
+        uvicorn.run("__main__:app", host=args.server, port=args.port, reload=True)
+    else:
+        uvicorn.run("__main__:app", host=args.server, port=args.port, reload=False)
+
+
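For reference, the new handleResponse helper treats any single line of the streamed assistant text that parses as JSON with "tool", "isCall", and "prompt" keys as a call to the imagine, calc, or search helpers. A minimal sketch of what such lines look like; only the key names and tool names come from the diff above, the prompt values are made up for illustration:

import json

# Hypothetical tool-call lines; the handler runs json.loads() on every line of the
# assistant response, so each call must occupy a line of its own.
calls = [
    {"tool": "imagine", "isCall": True, "prompt": "a watercolor fox"},
    {"tool": "calc", "isCall": True, "prompt": "2 * (3 + 4)"},
    {"tool": "search", "isCall": True, "prompt": "latest gradio release"},
]
print("\n".join(json.dumps(call) for call in calls))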