nishantgaurav23 committed
Commit e7d6236 · verified
1 Parent(s): 6f7b9d9

Update app.py

Files changed (1)
  1. app.py +338 -53
app.py CHANGED
@@ -67,6 +67,33 @@ def load_from_drive(file_id: str):
  st.error(f"Error loading file from Drive: {str(e)}")
  return None

  @st.cache_resource(show_spinner=False)
  def load_llama_model():
  """Load Llama model with caching"""
@@ -78,20 +105,37 @@ def load_llama_model():
  direct_url = "https://huggingface.co/TheBloke/Mistral-7B-v0.1-GGUF/resolve/main/mistral-7b-v0.1.Q4_K_M.gguf"
  download_file_with_progress(direct_url, model_path)

  llm_config = {
  "model_path": model_path,
  "n_ctx": 2048,
  "n_threads": 4,
  "n_batch": 512,
  "n_gpu_layers": 0,
- "verbose": False
  }

  model = Llama(**llm_config)
  st.success("Model loaded successfully!")
  return model
  except Exception as e:
- st.error(f"Error loading model: {str(e)}")
  raise

  def check_environment():
@@ -152,67 +196,152 @@ class RAGPipeline:
  logging.error(f"Error in query_model: {str(e)}")
  raise

- def process_query(self, query: str, placeholder) -> str:
- try:
- # Preprocess query
- query = self.preprocess_query(query)

- # Show retrieval status
- status = placeholder.empty()
- status.write("🔍 Finding relevant information...")

- # Get embeddings and search
- query_embedding = self.retriever.encode([query])
- similarities = F.cosine_similarity(query_embedding, self.retriever.doc_embeddings)
- scores, indices = torch.topk(similarities, k=min(self.k, len(self.documents)))

- relevant_docs = [self.documents[idx] for idx in indices.tolist()]

- # Update status
- status.write("💭 Generating response...")

- # Prepare context and prompt
- context = "\n".join(relevant_docs[:3])
- prompt = f"""Context information is below:
- {context}

- Given the context above, please answer the following question:
- {query}
-
- Guidelines:
- - If you cannot answer based on the context, say so politely
- - Keep the response concise and focused
- - Only include sports-related information
- - No dates or timestamps in the response
- - Use clear, natural language

- Answer:"""

- # Generate response
- response_placeholder = placeholder.empty()

- try:
- response_text = self.query_model(prompt)
- if response_text:
- final_response = self.postprocess_response(response_text)
- response_placeholder.markdown(final_response)
- return final_response
- else:
- message = "No relevant answer found. Please try rephrasing your question."
- response_placeholder.warning(message)
- return message

- except Exception as e:
- logging.error(f"Generation error: {str(e)}")
- message = "Had some trouble generating the response. Please try again."
- response_placeholder.warning(message)
- return message

  except Exception as e:
- logging.error(f"Process error: {str(e)}")
- message = "Something went wrong. Please try again with a different question."
- placeholder.warning(message)
  return message

  @st.cache_resource(show_spinner=False)
  def initialize_rag_pipeline():
@@ -244,6 +373,132 @@ def initialize_rag_pipeline():
  st.error(f"Failed to initialize the system: {str(e)}")
  raise

  def main():
  try:
  # Environment check
@@ -333,10 +588,17 @@ def main():
  </p>
  """, unsafe_allow_html=True)

- # Initialize the pipeline
  if 'rag' not in st.session_state:
- with st.spinner("Loading resources..."):
- st.session_state.rag = initialize_rag_pipeline()

  # Create columns for layout
  col1, col2, col3 = st.columns([1, 6, 1])
@@ -349,10 +611,18 @@ def main():
  if query:
  response_placeholder = st.empty()
  try:
  response = st.session_state.rag.process_query(query, response_placeholder)
  logging.info(f"Generated response: {response}")
  except Exception as e:
  logging.error(f"Query processing error: {str(e)}")
  response_placeholder.warning("Unable to process your question. Please try again.")
  else:
  st.warning("Please enter a question!")
@@ -368,7 +638,22 @@ def main():

  except Exception as e:
  logging.error(f"Application error: {str(e)}")
  st.error("An unexpected error occurred. Please check the logs and try again.")

  if __name__ == "__main__":
  main()
 
@@ -67,6 +67,33 @@ def load_from_drive(file_id: str):
  st.error(f"Error loading file from Drive: {str(e)}")
  return None

+ # @st.cache_resource(show_spinner=False)
+ # def load_llama_model():
+ # """Load Llama model with caching"""
+ # try:
+ # model_path = "mistral-7b-v0.1.Q4_K_M.gguf"
+
+ # if not os.path.exists(model_path):
+ # st.info("Downloading model... This may take a while.")
+ # direct_url = "https://huggingface.co/TheBloke/Mistral-7B-v0.1-GGUF/resolve/main/mistral-7b-v0.1.Q4_K_M.gguf"
+ # download_file_with_progress(direct_url, model_path)
+
+ # llm_config = {
+ # "model_path": model_path,
+ # "n_ctx": 2048,
+ # "n_threads": 4,
+ # "n_batch": 512,
+ # "n_gpu_layers": 0,
+ # "verbose": False
+ # }
+
+ # model = Llama(**llm_config)
+ # st.success("Model loaded successfully!")
+ # return model
+ # except Exception as e:
+ # st.error(f"Error loading model: {str(e)}")
+ # raise
+
  @st.cache_resource(show_spinner=False)
  def load_llama_model():
  """Load Llama model with caching"""
 
@@ -78,20 +105,37 @@ def load_llama_model():
  direct_url = "https://huggingface.co/TheBloke/Mistral-7B-v0.1-GGUF/resolve/main/mistral-7b-v0.1.Q4_K_M.gguf"
  download_file_with_progress(direct_url, model_path)

+ if not os.path.exists(model_path):
+ raise FileNotFoundError("Model file not found after download attempt")
+
+ if os.path.getsize(model_path) < 1000000: # Less than 1MB
+ raise ValueError("Model file is too small, likely corrupted")
+
  llm_config = {
  "model_path": model_path,
  "n_ctx": 2048,
  "n_threads": 4,
  "n_batch": 512,
  "n_gpu_layers": 0,
+ "verbose": True # Enable verbose mode for debugging
  }

+ logging.info("Initializing Llama model...")
  model = Llama(**llm_config)
+
+ # Test the model
+ logging.info("Testing model...")
+ test_response = model("Test", max_tokens=10)
+ if not test_response:
+ raise RuntimeError("Model test failed")
+
+ logging.info("Model loaded and tested successfully")
  st.success("Model loaded successfully!")
  return model
+
  except Exception as e:
+ logging.error(f"Error loading model: {str(e)}")
+ logging.error("Full error details: ", exc_info=True)
  raise

  def check_environment():
 
@@ -152,67 +196,152 @@ class RAGPipeline:
  logging.error(f"Error in query_model: {str(e)}")
  raise

+ # def process_query(self, query: str, placeholder) -> str:
+ # try:
+ # # Preprocess query
+ # query = self.preprocess_query(query)

+ # # Show retrieval status
+ # status = placeholder.empty()
+ # status.write("🔍 Finding relevant information...")

+ # # Get embeddings and search
+ # query_embedding = self.retriever.encode([query])
+ # similarities = F.cosine_similarity(query_embedding, self.retriever.doc_embeddings)
+ # scores, indices = torch.topk(similarities, k=min(self.k, len(self.documents)))

+ # relevant_docs = [self.documents[idx] for idx in indices.tolist()]

+ # # Update status
+ # status.write("💭 Generating response...")

+ # # Prepare context and prompt
+ # context = "\n".join(relevant_docs[:3])
+ # prompt = f"""Context information is below:
+ # {context}

+ # Given the context above, please answer the following question:
+ # {query}
+
+ # Guidelines:
+ # - If you cannot answer based on the context, say so politely
+ # - Keep the response concise and focused
+ # - Only include sports-related information
+ # - No dates or timestamps in the response
+ # - Use clear, natural language

+ # Answer:"""

+ # # Generate response
+ # response_placeholder = placeholder.empty()

+ # try:
+ # response_text = self.query_model(prompt)
+ # if response_text:
+ # final_response = self.postprocess_response(response_text)
+ # response_placeholder.markdown(final_response)
+ # return final_response
+ # else:
+ # message = "No relevant answer found. Please try rephrasing your question."
+ # response_placeholder.warning(message)
+ # return message

+ # except Exception as e:
+ # logging.error(f"Generation error: {str(e)}")
+ # message = "Had some trouble generating the response. Please try again."
+ # response_placeholder.warning(message)
+ # return message

+ # except Exception as e:
+ # logging.error(f"Process error: {str(e)}")
+ # message = "Something went wrong. Please try again with a different question."
+ # placeholder.warning(message)
+ # return message
+
+ def process_query(self, query: str, placeholder) -> str:
+ try:
+ # Preprocess query
+ query = self.preprocess_query(query)
+ logging.info(f"Processing query: {query}")
+
+ # Show retrieval status
+ status = placeholder.empty()
+ status.write("🔍 Finding relevant information...")
+
+ # Get embeddings and search
+ query_embedding = self.retriever.encode([query])
+ similarities = F.cosine_similarity(query_embedding, self.retriever.doc_embeddings)
+ scores, indices = torch.topk(similarities, k=min(self.k, len(self.documents)))
+
+ # Log similarity scores
+ for idx, score in zip(indices.tolist(), scores.tolist()):
+ logging.info(f"Score: {score:.4f} | Document: {self.documents[idx][:100]}...")
+
+ relevant_docs = [self.documents[idx] for idx in indices.tolist()]
+
+ # Update status
+ status.write("💭 Generating response...")
+
+ # Prepare context and prompt
+ context = "\n".join(relevant_docs[:3])
+ prompt = f"""Context information is below:
+ {context}
+
+ Given the context above, please answer the following question:
+ {query}
+
+ Guidelines:
+ - If you cannot answer based on the context, say so politely
+ - Keep the response concise and focused
+ - Only include sports-related information
+ - No dates or timestamps in the response
+ - Use clear, natural language
+
+ Answer:"""
+
+ # Generate response
+ response_placeholder = placeholder.empty()
+
+ try:
+ # Add logging for model state
+ logging.info("Model state check - Is None?: " + str(self.llm is None))
+
+ # Directly use Llama model
+ response = self.llm(
+ prompt,
+ max_tokens=512,
+ temperature=0.4,
+ top_p=0.95,
+ echo=False,
+ stop=["Question:", "\n\n"]
+ )
+
+ logging.info(f"Raw model response: {response}")
+
+ if response and isinstance(response, dict) and 'choices' in response:
+ generated_text = response['choices'][0].get('text', '').strip()
+ if generated_text:
+ final_response = self.postprocess_response(generated_text)
+ response_placeholder.markdown(final_response)
+ return final_response
+
+ message = "No relevant answer found. Please try rephrasing your question."
+ response_placeholder.warning(message)
+ return message
+
  except Exception as e:
+ logging.error(f"Generation error: {str(e)}")
+ logging.error(f"Full error details: ", exc_info=True)
+ message = f"Had some trouble generating the response: {str(e)}"
+ response_placeholder.warning(message)
  return message
+
+ except Exception as e:
+ logging.error(f"Process error: {str(e)}")
+ logging.error(f"Full error details: ", exc_info=True)
+ message = f"Something went wrong: {str(e)}"
+ placeholder.warning(message)
+ return message

  @st.cache_resource(show_spinner=False)
  def initialize_rag_pipeline():
 
@@ -244,6 +373,132 @@ def initialize_rag_pipeline():
  st.error(f"Failed to initialize the system: {str(e)}")
  raise

+ # def main():
+ # try:
+ # # Environment check
+ # if not check_environment():
+ # return
+
+ # # Improved CSS styling
+ # st.markdown("""
+ # <style>
+ # /* Container styling */
+ # .block-container {
+ # padding-top: 2rem;
+ # padding-bottom: 2rem;
+ # }
+
+ # /* Text input styling */
+ # .stTextInput > div > div > input {
+ # width: 100%;
+ # }
+
+ # /* Button styling */
+ # .stButton > button {
+ # width: 200px;
+ # margin: 0 auto;
+ # display: block;
+ # background-color: #FF4B4B;
+ # color: white;
+ # border-radius: 5px;
+ # padding: 0.5rem 1rem;
+ # }
+
+ # /* Title styling */
+ # .main-title {
+ # text-align: center;
+ # padding: 1rem 0;
+ # font-size: 3rem;
+ # color: #1F1F1F;
+ # }
+
+ # .sub-title {
+ # text-align: center;
+ # padding: 0.5rem 0;
+ # font-size: 1.5rem;
+ # color: #4F4F4F;
+ # }
+
+ # /* Description styling */
+ # .description {
+ # text-align: center;
+ # color: #666666;
+ # padding: 0.5rem 0;
+ # font-size: 1.1rem;
+ # line-height: 1.6;
+ # margin-bottom: 1rem;
+ # }
+
+ # /* Answer container styling */
+ # .stMarkdown {
+ # max-width: 100%;
+ # }
+
+ # /* Streamlit default overrides */
+ # .st-emotion-cache-16idsys p {
+ # font-size: 1.1rem;
+ # line-height: 1.6;
+ # }
+
+ # /* Container for main content */
+ # .main-content {
+ # max-width: 1200px;
+ # margin: 0 auto;
+ # padding: 0 1rem;
+ # }
+ # </style>
+ # """, unsafe_allow_html=True)
+
+ # # Header section
+ # st.markdown("<h1 class='main-title'>🏆 The Sport Chatbot</h1>", unsafe_allow_html=True)
+ # st.markdown("<h3 class='sub-title'>Using ESPN API</h3>", unsafe_allow_html=True)
+ # st.markdown("""
+ # <p class='description'>
+ # Hey there! 👋 I can help you with information on Ice Hockey, Baseball, American Football, Soccer, and Basketball.
+ # With access to the ESPN API, I'm up to date with the latest details for these sports up until October 2024.
+ # </p>
+ # <p class='description'>
+ # Got any general questions? Feel free to ask—I'll do my best to provide answers based on the information I've been trained on!
+ # </p>
+ # """, unsafe_allow_html=True)
+
+ # # Initialize the pipeline
+ # if 'rag' not in st.session_state:
+ # with st.spinner("Loading resources..."):
+ # st.session_state.rag = initialize_rag_pipeline()
+
+ # # Create columns for layout
+ # col1, col2, col3 = st.columns([1, 6, 1])
+
+ # with col2:
+ # # Query input
+ # query = st.text_input("What would you like to know about sports?")
+
+ # if st.button("Get Answer"):
+ # if query:
+ # response_placeholder = st.empty()
+ # try:
+ # response = st.session_state.rag.process_query(query, response_placeholder)
+ # logging.info(f"Generated response: {response}")
+ # except Exception as e:
+ # logging.error(f"Query processing error: {str(e)}")
+ # response_placeholder.warning("Unable to process your question. Please try again.")
+ # else:
+ # st.warning("Please enter a question!")
+
+ # # Footer
+ # st.markdown("<br><br>", unsafe_allow_html=True)
+ # st.markdown("---")
+ # st.markdown("""
+ # <p style='text-align: center; color: #666666; padding: 1rem 0;'>
+ # Powered by ESPN Data & Mistral AI 🚀
+ # </p>
+ # """, unsafe_allow_html=True)
+
+ # except Exception as e:
+ # logging.error(f"Application error: {str(e)}")
+ # st.error("An unexpected error occurred. Please check the logs and try again.")
+
  def main():
  try:
  # Environment check
 
@@ -333,10 +588,17 @@ def main():
  </p>
  """, unsafe_allow_html=True)

+ # Initialize the pipeline with better error handling
  if 'rag' not in st.session_state:
+ try:
+ with st.spinner("Loading resources..."):
+ st.session_state.rag = initialize_rag_pipeline()
+ logging.info("Pipeline initialized successfully")
+ except Exception as e:
+ logging.error(f"Pipeline initialization error: {str(e)}")
+ st.error("Failed to initialize the system. Please check the logs.")
+ st.stop()
+ return

  # Create columns for layout
  col1, col2, col3 = st.columns([1, 6, 1])
 
@@ -349,10 +611,18 @@ def main():
  if query:
  response_placeholder = st.empty()
  try:
+ # Log query processing start
+ logging.info(f"Processing query: {query}")
+
+ # Process query and get response
  response = st.session_state.rag.process_query(query, response_placeholder)
+
+ # Log successful response
  logging.info(f"Generated response: {response}")
  except Exception as e:
+ # Log error details
  logging.error(f"Query processing error: {str(e)}")
+ logging.error("Full error details: ", exc_info=True)
  response_placeholder.warning("Unable to process your question. Please try again.")
  else:
  st.warning("Please enter a question!")
 
@@ -368,7 +638,22 @@ def main():

  except Exception as e:
  logging.error(f"Application error: {str(e)}")
+ logging.error("Full error details: ", exc_info=True)
  st.error("An unexpected error occurred. Please check the logs and try again.")

+ if __name__ == "__main__":
+ # Configure logging
+ logging.basicConfig(
+ level=logging.INFO,
+ format='%(asctime)s - %(levelname)s - %(message)s'
+ )
+
+ try:
+ main()
+ except Exception as e:
+ logging.error(f"Fatal error: {str(e)}")
+ logging.error("Full error details: ", exc_info=True)
+ st.error("A fatal error occurred. Please check the logs and try again.")
+
  if __name__ == "__main__":
  main()