# Sport-Chatbot / app_modified.py
import os
import warnings
import logging
import sys
warnings.filterwarnings("ignore", category=UserWarning)
import numpy as np
import pandas as pd
import torch
from sentence_transformers import SentenceTransformer
from typing import List, Callable, Dict, Optional, Any
import glob
from tqdm import tqdm
import pickle
import torch.nn.functional as F
from llama_cpp import Llama
import streamlit as st
import functools
from datetime import datetime
import re
import time
import requests
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[logging.StreamHandler(sys.stdout)]
)
# Force CPU usage: hide CUDA devices before any CUDA context is created
# (torch initializes CUDA lazily, so setting this after import still works)
os.environ["CUDA_VISIBLE_DEVICES"] = ""
# Create necessary directories
for directory in ['models', 'ESPN_data', 'embeddings_cache']:
os.makedirs(directory, exist_ok=True)
# Logging configuration
LOGGING_CONFIG = {
'enabled': True,
'functions': {
'encode': True,
'store_embeddings': True,
'search': True,
'load_and_process_csvs': True,
'process_query': True
}
}
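# Individual functions can be muted without touching the decorator, e.g.
#   LOGGING_CONFIG['functions']['encode'] = False
# or all logging at once via LOGGING_CONFIG['enabled'] = False.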
def download_file_with_progress(url: str, filename: str):
    """Download a file with a progress bar using requests."""
    response = requests.get(url, stream=True, timeout=30)
    response.raise_for_status()  # fail fast instead of writing an HTML error page to disk
    total_size = int(response.headers.get('content-length', 0))
with open(filename, 'wb') as file, tqdm(
desc=filename,
total=total_size,
unit='iB',
unit_scale=True,
unit_divisor=1024,
) as progress_bar:
for data in response.iter_content(chunk_size=1024):
size = file.write(data)
progress_bar.update(size)
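# Example usage (hypothetical URL, shown only for illustration):
#   download_file_with_progress(
#       "https://example.com/some-model.gguf", "models/some-model.gguf")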
def log_function(func: Callable) -> Callable:
"""Decorator to log function inputs and outputs"""
@functools.wraps(func)
def wrapper(*args, **kwargs):
if not LOGGING_CONFIG['enabled'] or not LOGGING_CONFIG['functions'].get(func.__name__, False):
return func(*args, **kwargs)
if args and hasattr(args[0], '__class__'):
class_name = args[0].__class__.__name__
else:
class_name = func.__module__
timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')
log_args = args[1:] if class_name != func.__module__ else args
def format_arg(arg):
if isinstance(arg, torch.Tensor):
return f"Tensor(shape={list(arg.shape)}, device={arg.device})"
elif isinstance(arg, list):
return f"List(len={len(arg)})"
elif isinstance(arg, str) and len(arg) > 100:
return f"String(len={len(arg)}): {arg[:100]}..."
return arg
formatted_args = [format_arg(arg) for arg in log_args]
formatted_kwargs = {k: format_arg(v) for k, v in kwargs.items()}
print(f"\n{'='*80}")
print(f"[{timestamp}] FUNCTION CALL: {class_name}.{func.__name__}")
print(f"INPUTS:")
print(f" args: {formatted_args}")
print(f" kwargs: {formatted_kwargs}")
result = func(*args, **kwargs)
formatted_result = format_arg(result)
print(f"OUTPUT:")
print(f" {formatted_result}")
print(f"{'='*80}\n")
return result
return wrapper
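# With LOGGING_CONFIG as above, a decorated call such as retriever.encode(texts)
# prints a block along these lines (values are illustrative):
#   [2024-01-01 12:00:00.000000] FUNCTION CALL: SentenceTransformerRetriever.encode
#   INPUTS:
#     args: ['List(len=1000)']
#     kwargs: {}
#   OUTPUT:
#     Tensor(shape=[1000, 384], device=cpu)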
def check_environment():
    """Check that all required packages are importable."""
    try:
        import numpy
        import torch
        import sentence_transformers
        import llama_cpp
        return True
    except ImportError as e:
        st.error(f"Missing required package: {str(e)}")
        return False
class SentenceTransformerRetriever:
def __init__(self, model_name: str = "sentence-transformers/all-MiniLM-L6-v2", cache_dir: str = "embeddings_cache"):
self.device = torch.device("cpu")
self.model_name = model_name
self.cache_dir = cache_dir
self.cache_file = "embeddings.pkl"
self.doc_embeddings = None
os.makedirs(cache_dir, exist_ok=True)
self.model = self._load_model(model_name)
@st.cache_resource(show_spinner=False)
def _load_model(_self, _model_name: str):
try:
with warnings.catch_warnings():
warnings.simplefilter("ignore")
model = SentenceTransformer(_model_name, device="cpu")
test_embedding = model.encode("test", convert_to_tensor=True)
if not isinstance(test_embedding, torch.Tensor):
raise ValueError("Model initialization failed")
return model
except Exception as e:
logging.error(f"Error loading model: {str(e)}")
raise
    def get_cache_path(self, data_folder: str = None) -> str:
        # data_folder is accepted for interface symmetry but unused;
        # the cache lives at a single fixed location
        return os.path.join(self.cache_dir, self.cache_file)
@log_function
def save_cache(self, data_folder: str, cache_data: dict):
try:
cache_path = self.get_cache_path()
if os.path.exists(cache_path):
os.remove(cache_path)
with open(cache_path, 'wb') as f:
pickle.dump(cache_data, f)
logging.info(f"Cache saved at: {cache_path}")
except Exception as e:
logging.error(f"Error saving cache: {str(e)}")
raise
@log_function
@st.cache_data
def load_cache(_self, _data_folder: str = None) -> Optional[Dict]:
try:
cache_path = _self.get_cache_path()
if os.path.exists(cache_path):
with open(cache_path, 'rb') as f:
logging.info(f"Loading cache from: {cache_path}")
cache_data = pickle.load(f)
if isinstance(cache_data, dict) and 'embeddings' in cache_data and 'documents' in cache_data:
return cache_data
logging.warning("Invalid cache format")
return None
except Exception as e:
logging.error(f"Error loading cache: {str(e)}")
return None
@log_function
    def encode(self, texts: List[str], batch_size: int = 64) -> torch.Tensor:
        try:
            # Show a Streamlit progress bar while encoding in batches
            progress_bar = st.progress(0)
            all_embeddings = []
for i in range(0, len(texts), batch_size):
batch = texts[i:i + batch_size]
batch_embeddings = self.model.encode(
batch,
convert_to_tensor=True,
show_progress_bar=False # Disable tqdm progress bar
)
all_embeddings.append(batch_embeddings)
# Update progress
progress = min((i + batch_size) / len(texts), 1.0)
progress_bar.progress(progress)
# Clear progress bar
progress_bar.empty()
# Concatenate all embeddings
embeddings = torch.cat(all_embeddings, dim=0)
return F.normalize(embeddings, p=2, dim=1)
except Exception as e:
logging.error(f"Error encoding texts: {str(e)}")
raise
@log_function
def store_embeddings(self, embeddings: torch.Tensor):
self.doc_embeddings = embeddings
@log_function
def search(self, query_embedding: torch.Tensor, k: int, documents: List[str]):
try:
if self.doc_embeddings is None:
raise ValueError("No document embeddings stored!")
similarities = F.cosine_similarity(query_embedding, self.doc_embeddings)
k = min(k, len(documents))
scores, indices = torch.topk(similarities, k=k)
logging.info(f"\nSimilarity Stats:")
logging.info(f"Max similarity: {similarities.max().item():.4f}")
logging.info(f"Mean similarity: {similarities.mean().item():.4f}")
logging.info(f"Selected similarities: {scores.tolist()}")
return indices.cpu(), scores.cpu()
except Exception as e:
logging.error(f"Error in search: {str(e)}")
raise
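# Minimal sketch (not called by the app) of why `search` can rely on plain
# cosine similarity: `encode` L2-normalizes its output, so cosine similarity
# and dot product coincide for these embeddings.
def _cosine_vs_dot_sketch():
    query = F.normalize(torch.randn(1, 8), p=2, dim=1)  # one unit-norm query
    docs = F.normalize(torch.randn(5, 8), p=2, dim=1)   # five unit-norm docs
    cos = F.cosine_similarity(query, docs)              # shape [5], query broadcast
    dot = (query @ docs.T).squeeze(0)                   # identical up to float error
    assert torch.allclose(cos, dot, atol=1e-6)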
class RAGPipeline:
def __init__(self, data_folder: str, k: int = 5):
self.data_folder = data_folder
self.k = k
self.retriever = SentenceTransformerRetriever()
self.documents = []
self.device = torch.device("cpu")
        # Process documents first so embeddings exist before the LLM loads
        self.load_and_process_csvs()
        # GGUF model file; downloaded on first run if missing
        self.model_path = "mistral-7b-v0.1.Q4_K_M.gguf"
        # _initialize_model returns the Llama instance so the st.cache_resource
        # wrapper caches something reusable; a cached None would otherwise
        # silently skip model loading on reruns
        self.llm = self._initialize_model()
    @st.cache_resource  # cache the loaded Llama instance across Streamlit reruns
def _initialize_model(_self):
try:
if not os.path.exists(_self.model_path):
direct_url = "https://huggingface.co/TheBloke/Mistral-7B-v0.1-GGUF/resolve/main/mistral-7b-v0.1.Q4_K_M.gguf"
download_file_with_progress(direct_url, _self.model_path)
            # Verify the download actually produced the model file
            if not os.path.exists(_self.model_path):
                raise FileNotFoundError(f"Model file {_self.model_path} not found after download attempts")
            llm_config = {
                "n_ctx": 2048,
                "n_threads": 4,
                "n_batch": 512,
                "n_gpu_layers": 0,  # CPU-only inference
                "verbose": True     # log llama.cpp load details for debugging
            }
            llm = Llama(model_path=_self.model_path, **llm_config)
            st.success("Model loaded successfully!")
            return llm
        except Exception as e:
            logging.error(f"Error initializing model: {str(e)}")
            st.error(f"Error initializing model: {str(e)}")
            raise
def check_model_health(self):
try:
if self.llm is None:
return False
test_response = self.llm(
"Test prompt",
max_tokens=10,
temperature=0.4,
echo=False
)
return isinstance(test_response, dict) and 'choices' in test_response
except Exception:
return False
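    # For reference, llama-cpp-python's completion dict mirrors the OpenAI
    # completion schema, roughly:
    #   {"choices": [{"text": "...", "finish_reason": "stop", ...}], ...}
    # so the health check above looks for 'choices', and process_query below
    # reads response['choices'][0]['text'].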
    @log_function
    def load_and_process_csvs(_self):
        # Not wrapped in st.cache_data: a cache hit would skip this method's
        # side effects (populating _self.documents and the stored embeddings).
        # The pickle cache below plus the cached initialize_rag_pipeline()
        # already prevent recomputation.
try:
# Try loading from cache first
cache_data = _self.retriever.load_cache(_self.data_folder)
if cache_data is not None:
_self.documents = cache_data['documents']
_self.retriever.store_embeddings(cache_data['embeddings'])
st.success("Loaded documents from cache")
return
st.info("Processing documents... This may take a while.")
csv_files = glob.glob(os.path.join(_self.data_folder, "*.csv"))
if not csv_files:
raise FileNotFoundError(f"No CSV files found in {_self.data_folder}")
all_documents = []
total_files = len(csv_files)
# Create a progress bar
progress_bar = st.progress(0)
for idx, csv_file in enumerate(csv_files):
try:
                    df = pd.read_csv(csv_file, low_memory=False)  # avoid mixed-dtype chunk warnings
texts = df.apply(lambda x: " ".join(x.astype(str)), axis=1).tolist()
all_documents.extend(texts)
# Update progress
progress = (idx + 1) / total_files
progress_bar.progress(progress)
except Exception as e:
logging.error(f"Error processing file {csv_file}: {e}")
continue
# Clear progress bar
progress_bar.empty()
if not all_documents:
raise ValueError("No documents were successfully loaded")
st.info(f"Processing {len(all_documents)} documents...")
_self.documents = all_documents
embeddings = _self.retriever.encode(all_documents)
_self.retriever.store_embeddings(embeddings)
# Save to cache
cache_data = {
'embeddings': embeddings,
'documents': _self.documents
}
_self.retriever.save_cache(_self.data_folder, cache_data)
st.success("Document processing complete!")
except Exception as e:
logging.error(f"Error in load_and_process_csvs: {str(e)}")
raise
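    # The pickle cache written above has the shape
    #   {'embeddings': torch.Tensor [num_docs, dim], 'documents': List[str]},
    # which is exactly what load_cache validates when reading it back.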
def preprocess_query(self, query: str) -> str:
query = query.lower().strip()
query = re.sub(r'\s+', ' ', query)
return query
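    # e.g. preprocess_query("  Who   won the NBA  Finals? ") -> "who won the nba finals?"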
def postprocess_response(self, response: str) -> str:
response = response.strip()
response = re.sub(r'\s+', ' ', response)
response = re.sub(r'\d{4}-\d{2}-\d{2}\s\d{2}:\d{2}:\d{2}(?:\+\d{2}:?\d{2})?', '', response)
return response
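    # e.g. postprocess_response("Final score: 3-2\n2024-05-01 18:30:00+02:00")
    # -> "Final score: 3-2 " (whitespace collapsed, trailing timestamp stripped)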
@log_function
def process_query(self, query: str, placeholder) -> str:
try:
if self.llm is None:
raise RuntimeError("LLM model not initialized")
if self.retriever.model is None:
raise RuntimeError("Sentence transformer model not initialized")
query = self.preprocess_query(query)
status = placeholder.empty()
status.write("πŸ” Finding relevant information...")
query_embedding = self.retriever.encode([query])
indices, scores = self.retriever.search(query_embedding, self.k, self.documents)
logging.info("\nSearch Results:")
for idx, score in zip(indices.tolist(), scores.tolist()):
logging.info(f"Score: {score:.4f} | Document: {self.documents[idx][:100]}...")
relevant_docs = [self.documents[idx] for idx in indices.tolist()]
status.write("πŸ’­ Generating response...")
context = "\n".join(relevant_docs)
prompt = f"""Context information is below:
{context}
Given the context above, please answer the following question:
{query}
Guidelines:
- If you cannot answer based on the context, say so politely
- Keep the response concise and focused
- Only include sports-related information
- No dates or timestamps in the response
- Use clear, natural language
Answer:"""
response_placeholder = placeholder.empty()
try:
response = self.llm(
prompt,
max_tokens=512,
temperature=0.4,
top_p=0.95,
echo=False,
stop=["Question:", "\n\n"]
)
if response and 'choices' in response and len(response['choices']) > 0:
generated_text = response['choices'][0].get('text', '').strip()
if generated_text:
final_response = self.postprocess_response(generated_text)
response_placeholder.markdown(final_response)
return final_response
else:
message = "No relevant answer found. Please try rephrasing your question."
response_placeholder.warning(message)
return message
else:
message = "Unable to generate response. Please try again."
response_placeholder.warning(message)
return message
except Exception as e:
logging.error(f"Generation error: {str(e)}")
message = "Had some trouble generating the response. Please try again."
response_placeholder.warning(message)
return message
except Exception as e:
logging.error(f"Process error: {str(e)}")
message = "Something went wrong. Please try again with a different question."
placeholder.warning(message)
return message
@st.cache_resource(show_spinner=False)
def initialize_rag_pipeline():
"""Initialize the RAG pipeline once"""
try:
data_folder = "ESPN_data"
if not os.path.exists(data_folder):
os.makedirs(data_folder, exist_ok=True)
# Check for cache
cache_path = os.path.join("embeddings_cache", "embeddings.pkl")
if os.path.exists(cache_path):
st.info("Found cached data. Loading...")
else:
st.warning("Initial setup may take several minutes...")
rag = RAGPipeline(data_folder)
return rag
except Exception as e:
logging.error(f"Pipeline initialization error: {str(e)}")
st.error("Failed to initialize the system. Please check if all required files are present.")
raise
def main():
try:
# Environment check
if not check_environment():
return
# Page config
st.set_page_config(
page_title="The Sport Chatbot",
            page_icon="🏆",
layout="wide"
)
# Improved CSS styling
st.markdown("""
<style>
.block-container {
padding-top: 2rem;
padding-bottom: 2rem;
}
.stTextInput > div > div > input {
width: 100%;
}
.stButton > button {
width: 200px;
margin: 0 auto;
display: block;
background-color: #FF4B4B;
color: white;
border-radius: 5px;
padding: 0.5rem 1rem;
}
.main-title {
text-align: center;
padding: 1rem 0;
font-size: 3rem;
color: #1F1F1F;
}
.sub-title {
text-align: center;
padding: 0.5rem 0;
font-size: 1.5rem;
color: #4F4F4F;
}
.description {
text-align: center;
color: #666666;
padding: 0.5rem 0;
font-size: 1.1rem;
line-height: 1.6;
margin-bottom: 1rem;
}
.stMarkdown {
max-width: 100%;
}
.st-emotion-cache-16idsys p {
font-size: 1.1rem;
line-height: 1.6;
}
.main-content {
max-width: 1200px;
margin: 0 auto;
padding: 0 1rem;
}
</style>
""", unsafe_allow_html=True)
# Header section
st.markdown("<h1 class='main-title'>πŸ† The Sport Chatbot</h1>", unsafe_allow_html=True)
st.markdown("<h3 class='sub-title'>Using ESPN API</h3>", unsafe_allow_html=True)
st.markdown("""
<p class='description'>
            Hey there! 👋 I can help you with information on Ice Hockey, Baseball, American Football, Soccer, and Basketball.
With access to the ESPN API, I'm up to date with the latest details for these sports up until October 2024.
</p>
<p class='description'>
            Got any general questions? Feel free to ask, and I'll do my best to provide answers based on the information I've been trained on!
</p>
""", unsafe_allow_html=True)
# Add spacing
st.markdown("<br>", unsafe_allow_html=True)
# Initialize the pipeline
try:
with st.spinner("Loading resources..."):
rag = initialize_rag_pipeline()
# Add a model health check
if not rag.check_model_health():
st.error("Model initialization failed. Please try restarting the application.")
return
except Exception as e:
logging.error(f"Initialization error: {str(e)}")
st.error("Unable to initialize the system. Please check if all required files are present.")
return
        # Center the main content in a wide middle column
        col1, col2, col3 = st.columns([1, 6, 1])
with col2:
# Query input with label styling
query = st.text_input("What would you like to know about sports?")
# Centered button
if st.button("Get Answer"):
if query:
response_placeholder = st.empty()
try:
response = rag.process_query(query, response_placeholder)
logging.info(f"Generated response: {response}")
except Exception as e:
logging.error(f"Query processing error: {str(e)}")
response_placeholder.warning("Unable to process your question. Please try again.")
else:
st.warning("Please enter a question!")
# Footer
st.markdown("<br><br>", unsafe_allow_html=True)
st.markdown("---")
st.markdown("""
<p style='text-align: center; color: #666666; padding: 1rem 0;'>
            Powered by ESPN Data & Mistral AI 🚀
</p>
""", unsafe_allow_html=True)
except Exception as e:
logging.error(f"Application error: {str(e)}")
st.error("An unexpected error occurred. Please check the logs and try again.")
if __name__ == "__main__":
try:
main()
except Exception as e:
logging.error(f"Application error: {str(e)}")
st.error("An unexpected error occurred. Please check the logs and try again.")