rbn2008k commited on
Commit
4cf79aa
·
verified ·
1 Parent(s): d20890c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -58
app.py CHANGED
@@ -7,7 +7,6 @@ import gradio as gr
7
  import asyncio
8
  from collections import OrderedDict
9
  from datetime import datetime
10
- from transformers import AutoProcessor, AutoTokenizer
11
  import requests
12
  from openai import OpenAI
13
  from telethon import TelegramClient, events
@@ -16,14 +15,12 @@ from huggingface_hub import InferenceClient
16
  import pymongo
17
  from pymongo import MongoClient
18
 
19
- # Load system prompt from a file
20
  def load_system_prompt():
21
  with open('prompt.txt', 'r') as file:
22
  return file.read()
23
 
24
  system_prompt = load_system_prompt()
25
 
26
- # Environment variables
27
  api_id = os.getenv('api_id')
28
  api_hash = os.getenv('api_hash')
29
  bot_token = os.getenv('bot_token')
@@ -35,24 +32,22 @@ model1 = os.getenv('model1')
35
  model2 = os.getenv('model2')
36
  mongoURI = os.getenv('MONGO_URI')
37
 
38
- # Initialize OpenAI and MongoDB clients
39
  openai_client = OpenAI(api_key=openai_api_key, base_url=api_url)
40
  mongo_client = MongoClient(mongoURI)
41
  db = mongo_client['Scarlett']
42
  chat_collection = db['chats']
43
 
44
- # Initialize Hugging Face models for image processing
 
 
45
 
46
- # Local cache for up to 5 users
47
  local_chat_history = OrderedDict()
48
  MAX_LOCAL_USERS = 5
49
 
50
- # Retrieve chat history from MongoDB
51
  def get_history_from_mongo(user_id):
52
  result = chat_collection.find_one({"user_id": user_id})
53
  return result.get("messages", []) if result else []
54
 
55
- # Store message in MongoDB (limit to last 99 messages)
56
  def store_message_in_mongo(user_id, role, content):
57
  chat_collection.update_one(
58
  {"user_id": user_id},
@@ -67,35 +62,25 @@ def store_message_in_mongo(user_id, role, content):
67
  upsert=True
68
  )
69
 
70
- # Get chat history from local cache or MongoDB
71
  def get_chat_history(user_id):
72
  if user_id in local_chat_history:
73
- local_chat_history.move_to_end(user_id) # Mark as most recently used
74
  return local_chat_history[user_id]
75
-
76
- # Load from MongoDB if not in local cache
77
  history = get_history_from_mongo(user_id)
78
  local_chat_history[user_id] = history
79
-
80
  if len(local_chat_history) > MAX_LOCAL_USERS:
81
- local_chat_history.popitem(last=False) # Remove LRU user
82
-
83
  return history
84
 
85
- # Update chat history (both local and MongoDB)
86
  def update_chat_history(user_id, role, content):
87
  if user_id not in local_chat_history:
88
  local_chat_history[user_id] = get_history_from_mongo(user_id)
89
-
90
  local_chat_history[user_id].append({"role": role, "content": content})
91
  local_chat_history.move_to_end(user_id)
92
-
93
  if len(local_chat_history) > MAX_LOCAL_USERS:
94
  local_chat_history.popitem(last=False)
95
-
96
  store_message_in_mongo(user_id, role, content)
97
 
98
- # Encode image to base64
99
  def encode_local_image(image_path):
100
  im = Image.fromarray(image_path)
101
  buffered = BytesIO()
@@ -104,18 +89,13 @@ def encode_local_image(image_path):
104
  image_base64 = base64.b64encode(image_bytes).decode('ascii')
105
  return image_base64
106
 
107
- # Describe image using the model with error handling
108
  def inference_calling_idefics(image_path, question=""):
109
  system_prompt = os.getenv('USER_PROMPT')
110
- model_id = "HuggingFaceM4/idefics2-8b-chatty"
111
  client = InferenceClient(model=model_id)
112
- image_base64 = base64_encoded_image(image_path)
113
  image_info = f"data:image/png;base64,{image_base64}"
114
-
115
- # Include the system prompt before the user question
116
  prompt = f"{system_prompt}\n![]({image_info})\n{question}\n\n"
117
-
118
- # Adding do_sample=True to ensure different responses and max_tokens
119
  response = client.text_generation(
120
  prompt,
121
  max_new_tokens=512,
@@ -124,25 +104,20 @@ def inference_calling_idefics(image_path, question=""):
124
  )
125
  return response
126
 
127
- # Function to generate answers using the idefics model with system prompt and max_tokens
128
  def describe_image(image_path, question=""):
129
  try:
130
  answer = inference_calling_idefics(image_path, question)
131
  return answer
132
  except Exception as e:
133
  print(e)
134
- answer = "Error while seeing the image."
135
- return answer
136
 
137
- # Telegram bot client
138
  client = TelegramClient('bot', api_id, api_hash).start(bot_token=bot_token)
139
 
140
- # Get bot user ID for message filtering
141
  async def get_bot_id():
142
  me = await client.get_me()
143
  return me.id
144
 
145
- # Async function to get OpenAI completion
146
  async def get_completion(event, user_id, prompt):
147
  async with client.action(event.chat_id, 'typing'):
148
  await asyncio.sleep(3)
@@ -152,9 +127,8 @@ async def get_completion(event, user_id, prompt):
152
  *history,
153
  {"role": "user", "content": prompt},
154
  ]
155
-
156
  try:
157
- response = openai_client.chat.completions.create(
158
  model=model,
159
  messages=messages,
160
  max_tokens=512,
@@ -162,17 +136,18 @@ async def get_completion(event, user_id, prompt):
162
  top_p=1.0,
163
  frequency_penalty=0.9,
164
  presence_penalty=0.9,
 
165
  )
166
- message = response.choices[0].message.content
 
 
167
  except Exception as e:
168
  message = f"Whoops!"
169
  print(e)
170
-
171
- update_chat_history(user_id, "user", prompt) # Update history
172
- update_chat_history(user_id, "assistant", message) # Update assistant's response
173
  return message
174
 
175
- # Telegram bot events
176
  @client.on(events.NewMessage(pattern='/start'))
177
  async def start(event):
178
  await event.respond("Hello!")
@@ -184,34 +159,27 @@ async def help(event):
184
  @client.on(events.NewMessage(pattern='/reset'))
185
  async def reset(event):
186
  user_id = event.chat_id
187
- chat_collection.delete_one({"user_id": user_id}) # Reset MongoDB chat history for the user
188
  if user_id in local_chat_history:
189
- del local_chat_history[user_id] # Remove from local cache if present
190
  await event.respond("History has been reset.")
191
 
192
  @client.on(events.NewMessage)
193
  async def handle_message(event):
194
- bot_id = await get_bot_id() # Get bot ID to avoid responding to itself
195
-
 
196
  try:
197
- user_id = event.chat_id # Use chat_id to distinguish between users
198
-
199
- # Ignore messages from the bot itself
200
  if event.sender_id == bot_id:
201
  return
202
-
203
  user_message = event.raw_text
204
-
205
  if event.photo:
206
- # If an image is sent, describe the image
207
  photo = await event.download_media()
208
  image_description = describe_image(photo, user_message)
209
  user_message += f"\n\nI sent you an image. Content of the image: {image_description}"
210
-
211
- # Ignore command messages to prevent double processing
212
  if user_message.startswith('/start') or user_message.startswith('/help') or user_message.startswith('/reset'):
213
  return
214
-
215
  response = await get_completion(event, user_id, user_message)
216
  await event.respond(response)
217
  except Exception as e:
@@ -232,14 +200,13 @@ def launch_gradio():
232
  """)
233
  demo.launch(show_api=False)
234
 
235
- # Keep-alive function to keep the bot running
236
  def keep_alive():
237
  ping_client = OpenAI(api_key=ping_key, base_url=api_url)
238
  while True:
239
  try:
240
  messages = [
241
- {"role": "system", "content": "Be a helpful assistant."},
242
- {"role": "user", "content": "Hello"}
243
  ]
244
  request = ping_client.chat.completions.create(
245
  model=model,
@@ -248,10 +215,10 @@ def keep_alive():
248
  temperature=0.1,
249
  top_p=0.1,
250
  )
251
- print(request)
252
  except Exception as e:
253
  print(f"Keep-alive request failed: {e}")
254
- time.sleep(1800) # Ping every 30 minutes
255
 
256
  if __name__ == "__main__":
257
  threading.Thread(target=keep_alive).start()
 
7
  import asyncio
8
  from collections import OrderedDict
9
  from datetime import datetime
 
10
  import requests
11
  from openai import OpenAI
12
  from telethon import TelegramClient, events
 
15
  import pymongo
16
  from pymongo import MongoClient
17
 
 
18
  def load_system_prompt():
19
  with open('prompt.txt', 'r') as file:
20
  return file.read()
21
 
22
  system_prompt = load_system_prompt()
23
 
 
24
  api_id = os.getenv('api_id')
25
  api_hash = os.getenv('api_hash')
26
  bot_token = os.getenv('bot_token')
 
32
  model2 = os.getenv('model2')
33
  mongoURI = os.getenv('MONGO_URI')
34
 
 
35
  openai_client = OpenAI(api_key=openai_api_key, base_url=api_url)
36
  mongo_client = MongoClient(mongoURI)
37
  db = mongo_client['Scarlett']
38
  chat_collection = db['chats']
39
 
40
+ idefics_processor = AutoProcessor.from_pretrained(model1)
41
+ idefics_client = InferenceClient(model2)
42
+ tokenizer = AutoTokenizer.from_pretrained(model1)
43
 
 
44
  local_chat_history = OrderedDict()
45
  MAX_LOCAL_USERS = 5
46
 
 
47
  def get_history_from_mongo(user_id):
48
  result = chat_collection.find_one({"user_id": user_id})
49
  return result.get("messages", []) if result else []
50
 
 
51
  def store_message_in_mongo(user_id, role, content):
52
  chat_collection.update_one(
53
  {"user_id": user_id},
 
62
  upsert=True
63
  )
64
 
 
65
  def get_chat_history(user_id):
66
  if user_id in local_chat_history:
67
+ local_chat_history.move_to_end(user_id)
68
  return local_chat_history[user_id]
 
 
69
  history = get_history_from_mongo(user_id)
70
  local_chat_history[user_id] = history
 
71
  if len(local_chat_history) > MAX_LOCAL_USERS:
72
+ local_chat_history.popitem(last=False)
 
73
  return history
74
 
 
75
  def update_chat_history(user_id, role, content):
76
  if user_id not in local_chat_history:
77
  local_chat_history[user_id] = get_history_from_mongo(user_id)
 
78
  local_chat_history[user_id].append({"role": role, "content": content})
79
  local_chat_history.move_to_end(user_id)
 
80
  if len(local_chat_history) > MAX_LOCAL_USERS:
81
  local_chat_history.popitem(last=False)
 
82
  store_message_in_mongo(user_id, role, content)
83
 
 
84
  def encode_local_image(image_path):
85
  im = Image.fromarray(image_path)
86
  buffered = BytesIO()
 
89
  image_base64 = base64.b64encode(image_bytes).decode('ascii')
90
  return image_base64
91
 
 
92
  def inference_calling_idefics(image_path, question=""):
93
  system_prompt = os.getenv('USER_PROMPT')
94
+ model_id = model2
95
  client = InferenceClient(model=model_id)
96
+ image_base64 = describe_image(image_path)
97
  image_info = f"data:image/png;base64,{image_base64}"
 
 
98
  prompt = f"{system_prompt}\n![]({image_info})\n{question}\n\n"
 
 
99
  response = client.text_generation(
100
  prompt,
101
  max_new_tokens=512,
 
104
  )
105
  return response
106
 
 
107
  def describe_image(image_path, question=""):
108
  try:
109
  answer = inference_calling_idefics(image_path, question)
110
  return answer
111
  except Exception as e:
112
  print(e)
113
+ return "Error while seeing the image."
 
114
 
 
115
  client = TelegramClient('bot', api_id, api_hash).start(bot_token=bot_token)
116
 
 
117
  async def get_bot_id():
118
  me = await client.get_me()
119
  return me.id
120
 
 
121
  async def get_completion(event, user_id, prompt):
122
  async with client.action(event.chat_id, 'typing'):
123
  await asyncio.sleep(3)
 
127
  *history,
128
  {"role": "user", "content": prompt},
129
  ]
 
130
  try:
131
+ completion = openai_client.chat.completions.create(
132
  model=model,
133
  messages=messages,
134
  max_tokens=512,
 
136
  top_p=1.0,
137
  frequency_penalty=0.9,
138
  presence_penalty=0.9,
139
+ stream=True
140
  )
141
+ for chunk in completion:
142
+ if chunk.choices[0].delta.content is not None:
143
+ message += chunk.choices[0].delta.content
144
  except Exception as e:
145
  message = f"Whoops!"
146
  print(e)
147
+ update_chat_history(user_id, "user", prompt)
148
+ update_chat_history(user_id, "assistant", message)
 
149
  return message
150
 
 
151
  @client.on(events.NewMessage(pattern='/start'))
152
  async def start(event):
153
  await event.respond("Hello!")
 
159
  @client.on(events.NewMessage(pattern='/reset'))
160
  async def reset(event):
161
  user_id = event.chat_id
162
+ chat_collection.delete_one({"user_id": user_id})
163
  if user_id in local_chat_history:
164
+ del local_chat_history[user_id]
165
  await event.respond("History has been reset.")
166
 
167
  @client.on(events.NewMessage)
168
  async def handle_message(event):
169
+ async with client.action(event.chat_id, 'typing'):
170
+ await asyncio.sleep(1)
171
+ bot_id = await get_bot_id()
172
  try:
173
+ user_id = event.chat_id
 
 
174
  if event.sender_id == bot_id:
175
  return
 
176
  user_message = event.raw_text
 
177
  if event.photo:
 
178
  photo = await event.download_media()
179
  image_description = describe_image(photo, user_message)
180
  user_message += f"\n\nI sent you an image. Content of the image: {image_description}"
 
 
181
  if user_message.startswith('/start') or user_message.startswith('/help') or user_message.startswith('/reset'):
182
  return
 
183
  response = await get_completion(event, user_id, user_message)
184
  await event.respond(response)
185
  except Exception as e:
 
200
  """)
201
  demo.launch(show_api=False)
202
 
 
203
  def keep_alive():
204
  ping_client = OpenAI(api_key=ping_key, base_url=api_url)
205
  while True:
206
  try:
207
  messages = [
208
+ {"role": "system", "content": "Repeat what i say."},
209
+ {"role": "user", "content": "Repeat: 'Ping success'"}
210
  ]
211
  request = ping_client.chat.completions.create(
212
  model=model,
 
215
  temperature=0.1,
216
  top_p=0.1,
217
  )
218
+ print(request.choices[0].message.content)
219
  except Exception as e:
220
  print(f"Keep-alive request failed: {e}")
221
+ time.sleep(1800)
222
 
223
  if __name__ == "__main__":
224
  threading.Thread(target=keep_alive).start()