Update app.py
app.py
CHANGED
@@ -7,7 +7,6 @@ import gradio as gr
 import asyncio
 from collections import OrderedDict
 from datetime import datetime
-from transformers import AutoProcessor, AutoTokenizer
 import requests
 from openai import OpenAI
 from telethon import TelegramClient, events
@@ -16,14 +15,12 @@ from huggingface_hub import InferenceClient
 import pymongo
 from pymongo import MongoClient

-# Load system prompt from a file
 def load_system_prompt():
     with open('prompt.txt', 'r') as file:
         return file.read()

 system_prompt = load_system_prompt()

-# Environment variables
 api_id = os.getenv('api_id')
 api_hash = os.getenv('api_hash')
 bot_token = os.getenv('bot_token')
@@ -35,24 +32,22 @@ model1 = os.getenv('model1')
 model2 = os.getenv('model2')
 mongoURI = os.getenv('MONGO_URI')

-# Initialize OpenAI and MongoDB clients
 openai_client = OpenAI(api_key=openai_api_key, base_url=api_url)
 mongo_client = MongoClient(mongoURI)
 db = mongo_client['Scarlett']
 chat_collection = db['chats']

-
+idefics_processor = AutoProcessor.from_pretrained(model1)
+idefics_client = InferenceClient(model2)
+tokenizer = AutoTokenizer.from_pretrained(model1)

-# Local cache for up to 5 users
 local_chat_history = OrderedDict()
 MAX_LOCAL_USERS = 5

-# Retrieve chat history from MongoDB
 def get_history_from_mongo(user_id):
     result = chat_collection.find_one({"user_id": user_id})
     return result.get("messages", []) if result else []

-# Store message in MongoDB (limit to last 99 messages)
 def store_message_in_mongo(user_id, role, content):
     chat_collection.update_one(
         {"user_id": user_id},
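The body of store_message_in_mongo continues outside this hunk; only the removed comment records the intent of keeping the last 99 messages per user. A minimal sketch of how that cap is commonly enforced with pymongo, using $push together with $each and a negative $slice (the helper name store_message_capped and the localhost connection string are illustrative, not taken from the Space):

from pymongo import MongoClient

# Sketch only: 'chat_collection' mirrors the collection set up above.
chat_collection = MongoClient("mongodb://localhost:27017")["Scarlett"]["chats"]

def store_message_capped(user_id, role, content, limit=99):
    # $push with $each/$slice appends one message and trims the array
    # to the most recent `limit` entries in a single update.
    chat_collection.update_one(
        {"user_id": user_id},
        {"$push": {"messages": {"$each": [{"role": role, "content": content}],
                                "$slice": -limit}}},
        upsert=True,
    )

A negative $slice keeps the tail of the array, so only the most recent `limit` messages survive each update.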
@@ -67,35 +62,25 @@ def store_message_in_mongo(user_id, role, content):
         upsert=True
     )

-# Get chat history from local cache or MongoDB
 def get_chat_history(user_id):
     if user_id in local_chat_history:
-        local_chat_history.move_to_end(user_id)
+        local_chat_history.move_to_end(user_id)
         return local_chat_history[user_id]
-
-    # Load from MongoDB if not in local cache
     history = get_history_from_mongo(user_id)
     local_chat_history[user_id] = history
-
     if len(local_chat_history) > MAX_LOCAL_USERS:
-        local_chat_history.popitem(last=False)
-
+        local_chat_history.popitem(last=False)
     return history

-# Update chat history (both local and MongoDB)
 def update_chat_history(user_id, role, content):
     if user_id not in local_chat_history:
         local_chat_history[user_id] = get_history_from_mongo(user_id)
-
     local_chat_history[user_id].append({"role": role, "content": content})
     local_chat_history.move_to_end(user_id)
-
     if len(local_chat_history) > MAX_LOCAL_USERS:
         local_chat_history.popitem(last=False)
-
     store_message_in_mongo(user_id, role, content)

-# Encode image to base64
 def encode_local_image(image_path):
     im = Image.fromarray(image_path)
     buffered = BytesIO()
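get_chat_history and update_chat_history above use an OrderedDict as a small least-recently-used cache: move_to_end marks a user as fresh, and popitem(last=False) evicts the oldest entry once more than MAX_LOCAL_USERS users are held locally. A self-contained sketch of that pattern (touch_user is an illustrative name, not a function from the Space):

from collections import OrderedDict

MAX_LOCAL_USERS = 5
local_chat_history = OrderedDict()

def touch_user(user_id, history=None):
    # Insert or refresh a user, keeping at most MAX_LOCAL_USERS entries.
    if user_id in local_chat_history:
        local_chat_history.move_to_end(user_id)   # mark as most recently used
    else:
        local_chat_history[user_id] = history or []
    if len(local_chat_history) > MAX_LOCAL_USERS:
        local_chat_history.popitem(last=False)    # evict the least recently used

for uid in range(8):
    touch_user(uid)
print(list(local_chat_history))  # only the 5 most recent user ids remain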
@@ -104,18 +89,13 @@ def encode_local_image(image_path):
     image_base64 = base64.b64encode(image_bytes).decode('ascii')
     return image_base64

-# Describe image using the model with error handling
 def inference_calling_idefics(image_path, question=""):
     system_prompt = os.getenv('USER_PROMPT')
-    model_id =
+    model_id = model2
     client = InferenceClient(model=model_id)
-    image_base64 =
+    image_base64 = describe_image(image_path)
     image_info = f"data:image/png;base64,{image_base64}"
-
-    # Include the system prompt before the user question
     prompt = f"{system_prompt}\n\n{question}\n\n"
-
-    # Adding do_sample=True to ensure different responses and max_tokens
     response = client.text_generation(
         prompt,
         max_new_tokens=512,
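encode_local_image above is only partly visible in this hunk; it turns an image array into base64 so it can be embedded as a data URI. Note also that the new image_base64 = describe_image(image_path) line has inference_calling_idefics and describe_image calling each other, which looks mutually recursive as written and may be worth double-checking. A self-contained sketch of the array-to-data-URI round trip (encode_array_to_data_uri is an illustrative name; Pillow and NumPy are assumed):

import base64
from io import BytesIO

import numpy as np
from PIL import Image

def encode_array_to_data_uri(pixels: np.ndarray) -> str:
    # Mirror of encode_local_image above: array -> PNG bytes -> base64 data URI.
    im = Image.fromarray(pixels)
    buffered = BytesIO()
    im.save(buffered, format="PNG")
    image_base64 = base64.b64encode(buffered.getvalue()).decode("ascii")
    return f"data:image/png;base64,{image_base64}"

# Example: a 2x2 white RGB image.
print(encode_array_to_data_uri(np.full((2, 2, 3), 255, dtype=np.uint8))[:40])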
@@ -124,25 +104,20 @@ def inference_calling_idefics(image_path, question=""):
     )
     return response

-# Function to generate answers using the idefics model with system prompt and max_tokens
 def describe_image(image_path, question=""):
     try:
         answer = inference_calling_idefics(image_path, question)
         return answer
     except Exception as e:
         print(e)
-
-        return answer
+        return "Error while seeing the image."

-# Telegram bot client
 client = TelegramClient('bot', api_id, api_hash).start(bot_token=bot_token)

-# Get bot user ID for message filtering
 async def get_bot_id():
     me = await client.get_me()
     return me.id

-# Async function to get OpenAI completion
 async def get_completion(event, user_id, prompt):
     async with client.action(event.chat_id, 'typing'):
         await asyncio.sleep(3)
@@ -152,9 +127,8 @@ async def get_completion(event, user_id, prompt):
             *history,
             {"role": "user", "content": prompt},
         ]
-
         try:
-
+            completion = openai_client.chat.completions.create(
                 model=model,
                 messages=messages,
                 max_tokens=512,
@@ -162,17 +136,18 @@ async def get_completion(event, user_id, prompt):
                 top_p=1.0,
                 frequency_penalty=0.9,
                 presence_penalty=0.9,
+                stream=True
             )
-
+            for chunk in completion:
+                if chunk.choices[0].delta.content is not None:
+                    message += chunk.choices[0].delta.content
         except Exception as e:
             message = f"Whoops!"
             print(e)
-
-        update_chat_history(user_id, "
-        update_chat_history(user_id, "assistant", message) # Update assistant's response
+        update_chat_history(user_id, "user", prompt)
+        update_chat_history(user_id, "assistant", message)
     return message

-# Telegram bot events
 @client.on(events.NewMessage(pattern='/start'))
 async def start(event):
     await event.respond("Hello!")
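The hunk above is the core of the change: get_completion now passes stream=True and assembles the reply from the streamed delta chunks instead of reading a single response object. A standalone sketch of that accumulation pattern with the openai client (the key, base URL and model name are placeholders, and message starts as an empty string, which the full function is assumed to do outside this hunk):

from openai import OpenAI

# Placeholder credentials and endpoint.
client = OpenAI(api_key="sk-placeholder", base_url="https://example.invalid/v1")

def stream_reply(messages, model="placeholder-model"):
    message = ""  # accumulate the streamed text here
    completion = client.chat.completions.create(
        model=model,
        messages=messages,
        max_tokens=512,
        stream=True,  # the call now yields ChatCompletionChunk objects
    )
    for chunk in completion:
        delta = chunk.choices[0].delta.content
        if delta is not None:  # some chunks carry no text, e.g. the final one
            message += delta
    return message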
@@ -184,34 +159,27 @@ async def help(event):
 @client.on(events.NewMessage(pattern='/reset'))
 async def reset(event):
     user_id = event.chat_id
-    chat_collection.delete_one({"user_id": user_id})
+    chat_collection.delete_one({"user_id": user_id})
     if user_id in local_chat_history:
-        del local_chat_history[user_id]
+        del local_chat_history[user_id]
     await event.respond("History has been reset.")

 @client.on(events.NewMessage)
 async def handle_message(event):
-
-
+    async with client.action(event.chat_id, 'typing'):
+        await asyncio.sleep(1)
+    bot_id = await get_bot_id()
     try:
-        user_id = event.chat_id
-
-        # Ignore messages from the bot itself
+        user_id = event.chat_id
         if event.sender_id == bot_id:
             return
-
         user_message = event.raw_text
-
         if event.photo:
-            # If an image is sent, describe the image
             photo = await event.download_media()
             image_description = describe_image(photo, user_message)
             user_message += f"\n\nI sent you an image. Content of the image: {image_description}"
-
-        # Ignore command messages to prevent double processing
         if user_message.startswith('/start') or user_message.startswith('/help') or user_message.startswith('/reset'):
             return
-
         response = await get_completion(event, user_id, user_message)
         await event.respond(response)
     except Exception as e:
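handle_message now opens a typing action and resolves bot_id on every incoming message. In Telethon, client.action(chat, 'typing') is an async context manager that keeps the typing status visible while its block runs; a stripped-down sketch (the API id, hash and token are placeholders):

import asyncio
from telethon import TelegramClient, events

# Placeholder credentials; a real bot needs its own api_id, api_hash and bot_token.
client = TelegramClient('bot', 12345, 'api-hash-placeholder')

@client.on(events.NewMessage)
async def handle_message(event):
    async with client.action(event.chat_id, 'typing'):  # "typing..." stays visible inside the block
        await asyncio.sleep(1)  # stand-in for the real work
    await event.respond("done")

# client.start(bot_token='bot-token-placeholder')
# client.run_until_disconnected()

Since get_bot_id() only wraps client.get_me(), the id could also be fetched once at startup and reused, though calling it per message as the commit does works as well.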
@@ -232,14 +200,13 @@ def launch_gradio():
     """)
     demo.launch(show_api=False)

-# Keep-alive function to keep the bot running
 def keep_alive():
     ping_client = OpenAI(api_key=ping_key, base_url=api_url)
     while True:
         try:
             messages = [
-                {"role": "system", "content": "
-                {"role": "user", "content": "
+                {"role": "system", "content": "Repeat what i say."},
+                {"role": "user", "content": "Repeat: 'Ping success'"}
             ]
             request = ping_client.chat.completions.create(
                 model=model,
@@ -248,10 +215,10 @@ def keep_alive():
                 temperature=0.1,
                 top_p=0.1,
             )
-            print(request)
+            print(request.choices[0].message.content)
         except Exception as e:
             print(f"Keep-alive request failed: {e}")
-        time.sleep(1800)
+        time.sleep(1800)

 if __name__ == "__main__":
     threading.Thread(target=keep_alive).start()
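The keep-alive loop runs in a plain threading.Thread next to the Telethon event loop and pings the API every 1800 seconds. A compact sketch of that pattern with the ping call stubbed out; daemon=True is an assumption added here (the commit starts a non-daemon thread) so the thread cannot keep the process alive on its own:

import threading
import time

def keep_alive(interval_seconds=1800):
    # Stand-in loop: the Space calls ping_client.chat.completions.create(...) here.
    while True:
        try:
            print("ping")
        except Exception as e:
            print(f"Keep-alive request failed: {e}")
        time.sleep(interval_seconds)

if __name__ == "__main__":
    # daemon=True is not in the commit; it is assumed here so the process can exit cleanly.
    threading.Thread(target=keep_alive, daemon=True).start()
    time.sleep(5)  # stand-in for client.run_until_disconnected() / launch_gradio()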