nishantgaurav23 committed on
Commit 6f7b9d9 • 1 Parent(s): bc26371

Update app.py

Files changed (1):
  1. app.py +103 -316

app.py CHANGED
@@ -7,24 +7,14 @@ import torch
 import torch.nn.functional as F
 import re
 import requests
-#from dotenv import load_dotenv
 from embedding_processor import SentenceTransformerRetriever, process_data
 import pickle
-
-import os
-import warnings
-import json # Add this import
-
-# Add at the top with other imports
-from llama_cpp import Llama
-import requests
-from tqdm import tqdm
-
-
 import logging
 import sys
+from llama_cpp import Llama
+from tqdm import tqdm
 
-# Set page config immediately after imports
+# Set page config first
 st.set_page_config(
     page_title="The Sport Chatbot",
     page_icon="🏆",
@@ -38,16 +28,21 @@ logging.basicConfig(
     handlers=[logging.StreamHandler(sys.stdout)]
 )
 
-# Create necessary directories at startup
-for directory in ['models', 'ESPN_data', 'embeddings_cache']:
-    os.makedirs(directory, exist_ok=True)
-
-
-
-# Load environment variables
-#load_dotenv()
-
-# Add the new function here, right after imports and before API configuration
+def download_file_with_progress(url: str, filename: str):
+    """Download a file with progress bar using requests"""
+    response = requests.get(url, stream=True)
+    total_size = int(response.headers.get('content-length', 0))
+
+    with open(filename, 'wb') as file, tqdm(
+        desc=filename,
+        total=total_size,
+        unit='iB',
+        unit_scale=True,
+        unit_divisor=1024,
+    ) as progress_bar:
+        for data in response.iter_content(chunk_size=1024):
+            size = file.write(data)
+            progress_bar.update(size)
 
 @st.cache_data
 def load_from_drive(file_id: str):
@@ -72,93 +67,72 @@ def load_from_drive(file_id: str):
         st.error(f"Error loading file from Drive: {str(e)}")
         return None
 
-# Hugging Face API configuration
-
-# API_URL = "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-v0.1"
-# headers = {"Authorization": f"Bearer HF_TOKEN"}
-#model_name = 'mistralai/Mistral-7B-v0.1'
-
-
+@st.cache_resource(show_spinner=False)
+def load_llama_model():
+    """Load Llama model with caching"""
+    try:
+        model_path = "mistral-7b-v0.1.Q4_K_M.gguf"
+
+        if not os.path.exists(model_path):
+            st.info("Downloading model... This may take a while.")
+            direct_url = "https://huggingface.co/TheBloke/Mistral-7B-v0.1-GGUF/resolve/main/mistral-7b-v0.1.Q4_K_M.gguf"
+            download_file_with_progress(direct_url, model_path)
+
+        llm_config = {
+            "model_path": model_path,
+            "n_ctx": 2048,
+            "n_threads": 4,
+            "n_batch": 512,
+            "n_gpu_layers": 0,
+            "verbose": False
+        }
+
+        model = Llama(**llm_config)
+        st.success("Model loaded successfully!")
+        return model
+    except Exception as e:
+        st.error(f"Error loading model: {str(e)}")
+        raise
+
+def check_environment():
+    """Check if the environment is properly set up"""
+    try:
+        import torch
+        import sentence_transformers
+        return True
+    except ImportError as e:
+        st.error(f"Missing required package: {str(e)}")
+        st.stop()
+        return False
+
 class RAGPipeline:
-
     def __init__(self, data_folder: str, k: int = 5):
-        try:
-            self.data_folder = data_folder
-            self.k = k
-            self.retriever = SentenceTransformerRetriever()
-            self.documents = []
-            self.device = torch.device("cpu")
-
-            # Model path with absolute path
-            current_dir = os.path.dirname(os.path.abspath(__file__))
-            self.model_path = os.path.join(current_dir, "models", "mistral-7b-v0.1.Q4_K_M.gguf")
-
-            # Initialize model
-            self.llm = self.get_model()
-
-        except Exception as e:
-            logging.error(f"Error in RAGPipeline initialization: {str(e)}")
-            raise
-
-    @st.cache_resource(show_spinner=False)
-    def get_model(_self):
-        """Get or initialize the model with caching"""
-        try:
-            if not os.path.exists(_self.model_path):
-                os.makedirs(os.path.dirname(_self.model_path), exist_ok=True)
-                st.info("Downloading model... This may take a while.")
-                direct_url = "https://huggingface.co/TheBloke/Mistral-7B-v0.1-GGUF/resolve/main/mistral-7b-v0.1.Q4_K_M.gguf"
-                _self.download_file_with_progress(direct_url, _self.model_path)
-
-            # Verify file exists and has content
-            if not os.path.exists(_self.model_path):
-                raise FileNotFoundError(f"Model file {_self.model_path} not found after download attempts")
-
-            if os.path.getsize(_self.model_path) < 1000000: # Less than 1MB
-                os.remove(_self.model_path)
-                raise ValueError("Downloaded model file is too small, likely corrupted")
-
-            llm_config = {
-                "model_path": _self.model_path,
-                "n_ctx": 2048,
-                "n_threads": 4,
-                "n_batch": 512,
-                "n_gpu_layers": 0,
-                "verbose": False
-            }
-
-            model = Llama(**llm_config)
-            st.success("Model loaded successfully!")
-            return model
-
-        except Exception as e:
-            st.error(f"Error initializing model: {str(e)}")
-            raise
-
-    def download_file_with_progress(self, url: str, filename: str):
-        """Download a file with progress bar using requests"""
-        response = requests.get(url, stream=True)
-        total_size = int(response.headers.get('content-length', 0))
-
-        with open(filename, 'wb') as file, tqdm(
-            desc=filename,
-            total=total_size,
-            unit='iB',
-            unit_scale=True,
-            unit_divisor=1024,
-        ) as progress_bar:
-            for data in response.iter_content(chunk_size=1024):
-                size = file.write(data)
-                progress_bar.update(size)
-
-    # Alternative API call with streaming
+        self.data_folder = data_folder
+        self.k = k
+        self.retriever = SentenceTransformerRetriever()
+        self.documents = []
+        self.device = torch.device("cpu")
+        self.llm = load_llama_model()
+
+    def preprocess_query(self, query: str) -> str:
+        """Clean and prepare the query"""
+        query = query.lower().strip()
+        query = re.sub(r'\s+', ' ', query)
+        return query
+
+    def postprocess_response(self, response: str) -> str:
+        """Clean up the generated response"""
+        response = response.strip()
+        response = re.sub(r'\s+', ' ', response)
+        response = re.sub(r'\d{4}-\d{2}-\d{2}\s\d{2}:\d{2}:\d{2}(?:\+\d{2}:?\d{2})?', '', response)
+        return response
+
     def query_model(self, prompt: str) -> str:
-        """Query the local Llama model instead of API"""
+        """Query the local Llama model"""
         try:
             if self.llm is None:
                 raise RuntimeError("Model not initialized")
-
-            # Generate response using Llama model
+
             response = self.llm(
                 prompt,
                 max_tokens=512,
@@ -167,47 +141,41 @@ class RAGPipeline:
                 echo=False,
                 stop=["Question:", "\n\n"]
             )
-
-            # Check and extract response
+
             if response and 'choices' in response and len(response['choices']) > 0:
                 text = response['choices'][0].get('text', '').strip()
                 return text
             else:
                 raise ValueError("No valid response generated")
-
+
         except Exception as e:
             logging.error(f"Error in query_model: {str(e)}")
             raise
-    def preprocess_query(self, query: str) -> str:
-        """Clean and prepare the query"""
-        query = query.lower().strip()
-        query = re.sub(r'\s+', ' ', query)
-        return query
-
+
     def process_query(self, query: str, placeholder) -> str:
         try:
             # Preprocess query
             query = self.preprocess_query(query)
-
+
             # Show retrieval status
             status = placeholder.empty()
             status.write("🔍 Finding relevant information...")
-
+
             # Get embeddings and search
             query_embedding = self.retriever.encode([query])
             similarities = F.cosine_similarity(query_embedding, self.retriever.doc_embeddings)
             scores, indices = torch.topk(similarities, k=min(self.k, len(self.documents)))
-
+
             relevant_docs = [self.documents[idx] for idx in indices.tolist()]
-
+
             # Update status
             status.write("💭 Generating response...")
-
+
             # Prepare context and prompt
-            context = "\n".join(relevant_docs[:3]) # Use top 3 most relevant docs
+            context = "\n".join(relevant_docs[:3])
             prompt = f"""Context information is below:
 {context}
-
+
 Given the context above, please answer the following question:
 {query}
 
@@ -217,12 +185,12 @@ class RAGPipeline:
 - Only include sports-related information
 - No dates or timestamps in the response
 - Use clear, natural language
-
+
 Answer:"""
-
+
             # Generate response
             response_placeholder = placeholder.empty()
-
+
             try:
                 response_text = self.query_model(prompt)
                 if response_text:
@@ -233,174 +201,27 @@ class RAGPipeline:
                 message = "No relevant answer found. Please try rephrasing your question."
                 response_placeholder.warning(message)
                 return message
-
+
             except Exception as e:
                 logging.error(f"Generation error: {str(e)}")
                 message = "Had some trouble generating the response. Please try again."
                 response_placeholder.warning(message)
                 return message
-
+
         except Exception as e:
             logging.error(f"Process error: {str(e)}")
             message = "Something went wrong. Please try again with a different question."
             placeholder.warning(message)
             return message
 
-    def postprocess_response(self, response: str) -> str:
-        """Clean up the generated response"""
-        response = response.strip()
-        response = re.sub(r'\s+', ' ', response)
-        response = re.sub(r'\d{4}-\d{2}-\d{2}\s\d{2}:\d{2}:\d{2}(?:\+\d{2}:?\d{2})?', '', response)
-        return response
-
-
-    # def process_query(self, query: str, placeholder) -> str:
-    #     try:
-    #         # Preprocess query
-    #         query = self.preprocess_query(query)
-
-    #         # Show retrieval status
-    #         status = placeholder.empty()
-    #         status.write("🔍 Finding relevant information...")
-
-    #         # Get embeddings and search using tensor operations
-    #         query_embedding = self.retriever.encode([query])
-    #         similarities = F.cosine_similarity(query_embedding, self.retriever.doc_embeddings)
-    #         scores, indices = torch.topk(similarities, k=min(self.k, len(self.documents)))
-
-    #         # Print search results for debugging
-    #         print("\nSearch Results:")
-    #         for idx, score in zip(indices.tolist(), scores.tolist()):
-    #             print(f"Score: {score:.4f} | Document: {self.documents[idx][:100]}...")
-
-    #         relevant_docs = [self.documents[idx] for idx in indices.tolist()]
-
-    #         # Update status
-    #         status.write("💭 Generating response...")
-
-    #         # Prepare context and prompt
-    #         context = "\n".join(relevant_docs[:3]) # Only use top 3 most relevant docs
-    #         prompt = f"""Answer this question using the given context. Be specific and detailed.
-
-    # Context: {context}
-
-    # Question: {query}
-
-    # Answer (provide a complete, detailed response):"""
-
-    #         # Generate response
-    #         response_placeholder = placeholder.empty()
-
-    #         try:
-    #             response = requests.post(
-    #                 model_name,
-    #                 #headers=headers,
-    #                 json={
-    #                     "inputs": prompt,
-    #                     "parameters": {
-    #                         "max_new_tokens": 1024,
-    #                         "temperature": 0.5,
-    #                         "top_p": 0.9,
-    #                         "top_k": 50,
-    #                         "repetition_penalty": 1.03,
-    #                         "do_sample": True
-    #                     }
-    #                 },
-    #                 timeout=30
-    #             ).json()
-
-    #             if response and isinstance(response, list) and len(response) > 0:
-    #                 generated_text = response[0].get('generated_text', '').strip()
-    #                 if generated_text:
-    #                     # Find and extract only the answer part
-    #                     if "Answer:" in generated_text:
-    #                         answer_part = generated_text.split("Answer:")[-1].strip()
-    #                     elif "Answer (provide a complete, detailed response):" in generated_text:
-    #                         answer_part = generated_text.split("Answer (provide a complete, detailed response):")[-1].strip()
-    #                     else:
-    #                         answer_part = generated_text.strip()
-
-    #                     # Clean up the answer
-    #                     answer_part = answer_part.replace("Context:", "").replace("Question:", "")
-
-    #                     final_response = self.postprocess_response(answer_part)
-    #                     response_placeholder.markdown(final_response)
-    #                     return final_response
-
-    #             message = "No relevant answer found. Please try rephrasing your question."
-    #             response_placeholder.warning(message)
-    #             return message
-
-    #         except Exception as e:
-    #             print(f"Generation error: {str(e)}")
-    #             message = "Had some trouble generating the response. Please try again."
-    #             response_placeholder.warning(message)
-    #             return message
-
-    #     except Exception as e:
-    #         print(f"Process error: {str(e)}")
-    #         message = "Something went wrong. Please try again with a different question."
-    #         placeholder.warning(message)
-    #         return message
-def check_environment():
-    """Check if the environment is properly set up"""
-    # if not headers['Authorization']:
-    #     st.error("HUGGINGFACE_API_KEY environment variable not set!")
-    #     st.stop()
-    #     return False
-
-    try:
-        import torch
-        import sentence_transformers
-        return True
-    except ImportError as e:
-        st.error(f"Missing required package: {str(e)}")
-        st.stop()
-        return False
-
-# @st.cache_resource
-# def initialize_rag_pipeline():
-#     """Initialize the RAG pipeline once"""
-#     data_folder = "ESPN_data"
-#     return RAGPipeline(data_folder)
-def check_space_requirements():
-    """Check if we're running on HF Space and have necessary resources"""
-    try:
-        # Check if we're on HF Space
-        is_space = os.environ.get('SPACE_ID') is not None
-
-        if is_space:
-            # Check disk space
-            disk_space = os.statvfs('/')
-            free_space_gb = (disk_space.f_frsize * disk_space.f_bavail) / (1024**3)
-
-            if free_space_gb < 10: # Need at least 10GB free
-                st.warning(f"Low disk space: {free_space_gb:.1f}GB free")
-
-        # Check if model exists
-        model_path = "mistral-7b-v0.1.Q4_K_M.gguf"
-        if not os.path.exists(model_path):
-            st.info("Model will be downloaded on first run")
-
-        # Check if embeddings exist
-        if not os.path.exists('embeddings_cache/embeddings.pkl'):
-            st.info("Embeddings will be loaded from Drive")
-
-        return True
-
-    except Exception as e:
-        logging.error(f"Space requirements check failed: {str(e)}")
-        return False
-
 @st.cache_resource(show_spinner=False)
 def initialize_rag_pipeline():
     """Initialize the RAG pipeline once"""
     try:
-        # First check/create necessary directories
-        for directory in ['models', 'ESPN_data', 'embeddings_cache']:
-            os.makedirs(directory, exist_ok=True)
-
-        # Load embeddings from Drive first
+        # Create necessary directories
+        os.makedirs("ESPN_data", exist_ok=True)
+
+        # Load embeddings from Drive
         drive_file_id = "1MuV63AE9o6zR9aBvdSDQOUextp71r2NN"
         with st.spinner("Loading embeddings from Google Drive..."):
            cache_data = load_from_drive(drive_file_id)
@@ -408,7 +229,7 @@ def initialize_rag_pipeline():
             st.error("Failed to load embeddings from Google Drive")
             st.stop()
 
-        # Now initialize pipeline
+        # Initialize pipeline
         data_folder = "ESPN_data"
         rag = RAGPipeline(data_folder)
 
@@ -426,20 +247,9 @@ def main():
 def main():
     try:
         # Environment check
-        if not check_environment() or not check_space_requirements():
+        if not check_environment():
             return
 
-        # Session state for initialization status
-        if 'initialized' not in st.session_state:
-            st.session_state.initialized = False
-
-        # # Page config
-        # st.set_page_config(
-        #     page_title="The Sport Chatbot",
-        #     page_icon="🏆",
-        #     layout="wide"
-        # )
-
         # Improved CSS styling
         st.markdown("""
             <style>
@@ -510,7 +320,7 @@ def main():
             </style>
         """, unsafe_allow_html=True)
 
-        # Header section with improved styling
+        # Header section
         st.markdown("<h1 class='main-title'>🏆 The Sport Chatbot</h1>", unsafe_allow_html=True)
         st.markdown("<h3 class='sub-title'>Using ESPN API</h3>", unsafe_allow_html=True)
         st.markdown("""
@@ -523,40 +333,22 @@ def main():
             </p>
         """, unsafe_allow_html=True)
 
-        # Add some spacing
-        st.markdown("<br>", unsafe_allow_html=True)
-
         # Initialize the pipeline
-        if not st.session_state.initialized:
-            try:
-                with st.spinner("Loading resources..."):
-                    # Create necessary directories
-                    for directory in ['models', 'ESPN_data', 'embeddings_cache']:
-                        os.makedirs(directory, exist_ok=True)
-
-                    # Initialize RAG pipeline
-                    st.session_state.rag = initialize_rag_pipeline()
-                    st.session_state.initialized = True
-
-                st.success("System initialized successfully!")
-            except Exception as e:
-                logging.error(f"Initialization error: {str(e)}")
-                st.error("Unable to initialize the system. Please check if all required files are present.")
-                st.stop()
+        if 'rag' not in st.session_state:
+            with st.spinner("Loading resources..."):
+                st.session_state.rag = initialize_rag_pipeline()
 
-        # Create columns for layout with golden ratio
+        # Create columns for layout
         col1, col2, col3 = st.columns([1, 6, 1])
 
         with col2:
-            # Query input with label styling
+            # Query input
             query = st.text_input("What would you like to know about sports?")
 
-            # Centered button
             if st.button("Get Answer"):
                 if query:
                     response_placeholder = st.empty()
                     try:
-                        # Get response from RAG pipeline
                         response = st.session_state.rag.process_query(query, response_placeholder)
                         logging.info(f"Generated response: {response}")
                     except Exception as e:
@@ -565,13 +357,12 @@ def main():
                 else:
                     st.warning("Please enter a question!")
 
-        # Footer with improved styling
+        # Footer
         st.markdown("<br><br>", unsafe_allow_html=True)
         st.markdown("---")
         st.markdown("""
             <p style='text-align: center; color: #666666; padding: 1rem 0;'>
-                Powered by ESPN Data & Mistral AI 🚀<br>
-                <small>Running on Hugging Face Spaces</small>
+                Powered by ESPN Data & Mistral AI 🚀
             </p>
         """, unsafe_allow_html=True)
 
@@ -580,8 +371,4 @@ def main():
         st.error("An unexpected error occurred. Please check the logs and try again.")
 
 if __name__ == "__main__":
-    try:
-        main()
-    except Exception as e:
-        logging.error(f"Application error: {str(e)}")
-        st.error("An unexpected error occurred. Please check the logs and try again.")
+    main()