nishantgaurav23 commited on
Commit
ea8d2d0
·
verified ·
1 Parent(s): 08ce49b

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +384 -0
app.py ADDED
@@ -0,0 +1,384 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import os
import warnings

# Silence noisy UserWarnings (e.g. from transformers/torch) before the
# heavyweight imports below trigger them.
warnings.filterwarnings("ignore", category=UserWarning)

import json  # used to decode streamed Hugging Face API responses
import pickle
import re

import requests
import streamlit as st
import torch
import torch.nn.functional as F
from dotenv import load_dotenv

from embedding_processor import SentenceTransformerRetriever, process_data

# FIX: the original imported `os` and `warnings` twice; duplicates removed
# and imports grouped stdlib / third-party / local per PEP 8.

# Load environment variables (e.g. HUGGINGFACE_API_KEY) from a local .env file
load_dotenv()
23
# Cached so the (potentially large) pickle is downloaded from Drive only once
# per session. FIX: the original stacked @st.cache_data twice, which is
# redundant -- a single application is sufficient.
@st.cache_data
def load_from_drive(file_id: str):
    """Download a pickle file directly from Google Drive and deserialize it.

    Args:
        file_id: The Google Drive file id (the ``id=`` query parameter).

    Returns:
        The unpickled object, or ``None`` if the download or
        deserialization failed (an error is also shown in the Streamlit UI).
    """
    try:
        # Direct download URL for Google Drive
        url = f"https://drive.google.com/uc?id={file_id}&export=download"

        # First request: for large files Drive answers with a confirmation
        # page (and a warning cookie) instead of the file itself.
        session = requests.Session()
        response = session.get(url, stream=True)

        # If Drive set a download-warning cookie, retry with the
        # confirmation token appended to the URL.
        for key, value in response.cookies.items():
            if key.startswith('download_warning'):
                url = f"{url}&confirm={value}"
                response = session.get(url, stream=True)
                break

        # Load the content and unpickle it.
        content = response.content
        print(f"Successfully downloaded {len(content)} bytes")
        # SECURITY NOTE: pickle.loads executes arbitrary code from the
        # payload -- only point file_id at a Drive file you control.
        return pickle.loads(content)

    except Exception as e:
        print(f"Detailed error: {str(e)}")  # server-side log for debugging
        st.error(f"Error loading file from Drive: {str(e)}")
        return None
53
+
54
# Hugging Face Inference API configuration

# Hosted Mistral-7B (base, v0.1) text-generation endpoint.
API_URL = "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-v0.1"
# NOTE(review): if HUGGINGFACE_API_KEY is unset, os.getenv returns None and
# this header becomes the literal (truthy) string "Bearer None" -- see
# check_environment().
headers = {"Authorization": f"Bearer {os.getenv('HUGGINGFACE_API_KEY')}"}
58
+
59
+
60
class RAGPipeline:
    """Retrieval-augmented generation pipeline.

    Embeds a folder of documents (via embedding_processor), retrieves the
    top-k documents most similar to a user query by cosine similarity, and
    asks the hosted Mistral-7B model to answer using them as context.
    """

    def __init__(self, data_folder: str, k: int = 3): # Reduced k for faster retrieval
        # Folder containing the raw documents to index.
        self.data_folder = data_folder
        # Number of documents to retrieve per query.
        self.k = k
        self.retriever = SentenceTransformerRetriever()
        # process_data returns both the documents and their embeddings
        # (presumably cached -- see embedding_processor; TODO confirm).
        cache_data = process_data(data_folder)
        self.documents = cache_data['documents']
        self.retriever.store_embeddings(cache_data['embeddings'])

    # Alternative API call with streaming.
    # NOTE(review): this method is not called anywhere in this file;
    # process_query issues its own non-streaming requests.post instead.
    def query_model(self, payload):
        """Query the Hugging Face API with streaming.

        Sends ``payload`` with streaming enabled, concatenates the streamed
        chunks, and returns the result in the same
        ``[{"generated_text": ...}]`` shape as a non-streaming call.

        Raises:
            requests.exceptions.RequestException: on any HTTP failure.
        """
        try:
            # Add streaming parameters (NOTE: mutates the caller's dict).
            payload["parameters"]["stream"] = True

            response = requests.post(
                API_URL,
                headers=headers,
                json=payload,
                stream=True
            )
            response.raise_for_status()

            # Collect the entire response from the streamed lines.
            full_response = ""
            for line in response.iter_lines():
                if line:
                    try:
                        json_response = json.loads(line)
                        if isinstance(json_response, list) and len(json_response) > 0:
                            chunk_text = json_response[0].get('generated_text', '')
                            if chunk_text:
                                full_response += chunk_text
                    except json.JSONDecodeError as e:
                        # Skip malformed chunks rather than aborting the stream.
                        print(f"Error decoding JSON: {e}")
                        continue

            return [{"generated_text": full_response}]

        except requests.exceptions.RequestException as e:
            print(f"API request failed: {str(e)}")
            raise

    def preprocess_query(self, query: str) -> str:
        """Clean and prepare the query: lowercase, strip, collapse whitespace."""
        query = query.lower().strip()
        query = re.sub(r'\s+', ' ', query)
        return query

    def postprocess_response(self, response: str) -> str:
        """Clean up the generated response.

        Collapses whitespace and strips ISO-style timestamps
        (e.g. "2024-01-31 12:00:00+05:30") that the model sometimes emits.
        """
        response = response.strip()
        response = re.sub(r'\s+', ' ', response)
        response = re.sub(r'\d{4}-\d{2}-\d{2}\s\d{2}:\d{2}:\d{2}(?:\+\d{2}:?\d{2})?', '', response)
        return response

    def process_query(self, query: str, placeholder) -> str:
        """Answer ``query`` end-to-end, rendering progress into ``placeholder``.

        Retrieves the top-k documents, builds a prompt, calls the hosted
        model, post-processes the answer and writes it to the Streamlit
        placeholder. Always returns a string (the answer or a user-facing
        error message) -- all exceptions are caught and reported in the UI.
        """
        try:
            # Preprocess query
            query = self.preprocess_query(query)

            # Show retrieval status
            status = placeholder.empty()
            status.write("🔍 Finding relevant information...")

            # Get embeddings and search using tensor operations
            query_embedding = self.retriever.encode([query])
            similarities = F.cosine_similarity(query_embedding, self.retriever.doc_embeddings)
            # Clamp k so topk never asks for more results than documents exist.
            scores, indices = torch.topk(similarities, k=min(self.k, len(self.documents)))

            # Print search results for debugging (server log only)
            print("\nSearch Results:")
            for idx, score in zip(indices.tolist(), scores.tolist()):
                print(f"Score: {score:.4f} | Document: {self.documents[idx][:100]}...")

            relevant_docs = [self.documents[idx] for idx in indices.tolist()]

            # Update status
            status.write("💭 Generating response...")

            # Prepare context and prompt
            context = "\n".join(relevant_docs[:3])  # Only use top 3 most relevant docs
            prompt = f"""Answer this question using the given context. Be specific and detailed.

Context: {context}

Question: {query}

Answer (provide a complete, detailed response):"""

            # Generate response
            response_placeholder = placeholder.empty()

            try:
                response = requests.post(
                    API_URL,
                    headers=headers,
                    json={
                        "inputs": prompt,
                        "parameters": {
                            "max_new_tokens": 1024,
                            "temperature": 0.5,
                            "top_p": 0.9,
                            "top_k": 50,
                            "repetition_penalty": 1.03,
                            "do_sample": True
                        }
                    },
                    timeout=30
                ).json()
                # NOTE(review): no raise_for_status() here -- an HTTP error
                # body is parsed as JSON and falls through to the
                # "No relevant answer" path below.

                if response and isinstance(response, list) and len(response) > 0:
                    generated_text = response[0].get('generated_text', '').strip()
                    if generated_text:
                        # Find and extract only the answer part (the model may
                        # echo the prompt back in its output).
                        if "Answer:" in generated_text:
                            answer_part = generated_text.split("Answer:")[-1].strip()
                        elif "Answer (provide a complete, detailed response):" in generated_text:
                            answer_part = generated_text.split("Answer (provide a complete, detailed response):")[-1].strip()
                        else:
                            answer_part = generated_text.strip()

                        # Clean up the answer: drop any echoed prompt labels.
                        answer_part = answer_part.replace("Context:", "").replace("Question:", "")

                        final_response = self.postprocess_response(answer_part)
                        response_placeholder.markdown(final_response)
                        return final_response

                message = "No relevant answer found. Please try rephrasing your question."
                response_placeholder.warning(message)
                return message

            except Exception as e:
                print(f"Generation error: {str(e)}")
                message = "Had some trouble generating the response. Please try again."
                response_placeholder.warning(message)
                return message

        except Exception as e:
            print(f"Process error: {str(e)}")
            message = "Something went wrong. Please try again with a different question."
            placeholder.warning(message)
            return message
208
def check_environment():
    """Check if the environment is properly set up.

    Verifies the Hugging Face API key is configured and the required
    packages import cleanly. On failure it calls ``st.stop()`` (which
    raises), so the ``return False`` lines are defensive only.

    Returns:
        bool: True when everything needed is available.
    """
    # BUG FIX: the original tested headers['Authorization'], which is always
    # the truthy string "Bearer None" when the key is missing (the f-string
    # interpolates None), so the guard could never fire. Test the
    # environment variable directly instead.
    if not os.getenv('HUGGINGFACE_API_KEY'):
        st.error("HUGGINGFACE_API_KEY environment variable not set!")
        st.stop()
        return False

    try:
        # Import check only: these are needed by the retriever at runtime.
        import torch
        import sentence_transformers
        return True
    except ImportError as e:
        st.error(f"Missing required package: {str(e)}")
        st.stop()
        return False
223
+
224
+ # @st.cache_resource
225
+ # def initialize_rag_pipeline():
226
+ # """Initialize the RAG pipeline once"""
227
+ # data_folder = "ESPN_data"
228
+ # return RAGPipeline(data_folder)
229
+
230
@st.cache_resource
def initialize_rag_pipeline():
    """Initialize the RAG pipeline once per server process (st.cache_resource).

    Downloads precomputed documents/embeddings from Google Drive and installs
    them on the pipeline, replacing whatever RAGPipeline's constructor
    computed from the local data folder.
    """
    data_folder = "ESPN_data"
    # Drive file id of the pickled {'documents': ..., 'embeddings': ...} cache.
    drive_file_id = "1MuV63AE9o6zR9aBvdSDQOUextp71r2NN"

    with st.spinner("Loading embeddings from Google Drive..."):
        cache_data = load_from_drive(drive_file_id)
        if cache_data is None:
            st.error("Failed to load embeddings from Google Drive")
            st.stop()

    # NOTE(review): RAGPipeline.__init__ already runs process_data on the
    # local folder; those results are discarded and overwritten below, so
    # this does redundant work -- consider a constructor flag to skip it.
    rag = RAGPipeline(data_folder)
    rag.documents = cache_data['documents']
    rag.retriever.store_embeddings(cache_data['embeddings'])
    return rag
246
+
247
def main():
    """Streamlit entry point: render the sports chatbot UI and answer user
    questions through the cached RAG pipeline."""
    # Page config -- must be the FIRST Streamlit command executed.
    # BUG FIX: this originally ran *after* check_environment(), whose
    # st.error/st.stop calls on the failure path would raise
    # StreamlitAPIException ("set_page_config must be the first command").
    st.set_page_config(
        page_title="The Sport Chatbot",
        page_icon="🏆",
        layout="wide"
    )

    # Environment check (may stop the script with an error banner).
    if not check_environment():
        return

    # Improved CSS styling
    st.markdown("""
        <style>
        /* Container styling */
        .block-container {
            padding-top: 2rem;
            padding-bottom: 2rem;
        }

        /* Text input styling */
        .stTextInput > div > div > input {
            width: 100%;
        }

        /* Button styling */
        .stButton > button {
            width: 200px;
            margin: 0 auto;
            display: block;
            background-color: #FF4B4B;
            color: white;
            border-radius: 5px;
            padding: 0.5rem 1rem;
        }

        /* Title styling */
        .main-title {
            text-align: center;
            padding: 1rem 0;
            font-size: 3rem;
            color: #1F1F1F;
        }

        .sub-title {
            text-align: center;
            padding: 0.5rem 0;
            font-size: 1.5rem;
            color: #4F4F4F;
        }

        /* Description styling */
        .description {
            text-align: center;
            color: #666666;
            padding: 0.5rem 0;
            font-size: 1.1rem;
            line-height: 1.6;
            margin-bottom: 1rem;
        }

        /* Answer container styling */
        .stMarkdown {
            max-width: 100%;
        }

        /* Streamlit default overrides */
        .st-emotion-cache-16idsys p {
            font-size: 1.1rem;
            line-height: 1.6;
        }

        /* Container for main content */
        .main-content {
            max-width: 1200px;
            margin: 0 auto;
            padding: 0 1rem;
        }
        </style>
    """, unsafe_allow_html=True)

    # Header section with improved styling
    st.markdown("<h1 class='main-title'>🏆 The Sport Chatbot</h1>", unsafe_allow_html=True)
    st.markdown("<h3 class='sub-title'>Using ESPN API</h3>", unsafe_allow_html=True)
    st.markdown("""
        <p class='description'>
        Hey there! 👋 I can help you with information on Ice Hockey, Baseball, American Football, Soccer, and Basketball.
        With access to the ESPN API, I'm up to date with the latest details for these sports up until October 2024.
        </p>
        <p class='description'>
        Got any general questions? Feel free to ask—I'll do my best to provide answers based on the information I've been trained on!
        </p>
    """, unsafe_allow_html=True)

    # Add some spacing
    st.markdown("<br>", unsafe_allow_html=True)

    # Initialize the pipeline (cached across reruns by st.cache_resource)
    try:
        with st.spinner("Loading resources..."):
            rag = initialize_rag_pipeline()
    except Exception as e:
        print(f"Initialization error: {str(e)}")
        st.error("Unable to initialize the system. Please check if all required files are present.")
        st.stop()

    # Create columns for layout with golden ratio
    col1, col2, col3 = st.columns([1, 6, 1])

    with col2:
        # Query input with label styling
        query = st.text_input("What would you like to know about sports?")

        # Centered button
        if st.button("Get Answer"):
            if query:
                response_placeholder = st.empty()
                try:
                    response = rag.process_query(query, response_placeholder)
                    print(f"Generated response: {response}")
                except Exception as e:
                    print(f"Query processing error: {str(e)}")
                    response_placeholder.warning("Unable to process your question. Please try again.")
            else:
                st.warning("Please enter a question!")

    # Footer with improved styling
    st.markdown("<br><br>", unsafe_allow_html=True)
    st.markdown("---")
    st.markdown("""
        <p style='text-align: center; color: #666666; padding: 1rem 0;'>
        Powered by ESPN Data & Mistral AI 🚀
        </p>
    """, unsafe_allow_html=True)
382
+
383
+ if __name__ == "__main__":
384
+ main()