import os
import sys
import warnings
import logging

warnings.filterwarnings("ignore", category=UserWarning)

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from sentence_transformers import SentenceTransformer
from llama_cpp import Llama
import streamlit as st
from typing import List, Callable, Dict, Optional, Any
import glob
from tqdm import tqdm
import pickle
import functools
from datetime import datetime
import re
import time
import requests
# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler(sys.stdout)]
)

# Force CPU: a bare torch.device('cpu') call is a no-op, so pin the default
# device explicitly (CPU is already torch's default, and every model below
# also requests CPU, so this is belt-and-braces).
torch.set_default_device('cpu')

# Create necessary directories
for directory in ['models', 'ESPN_data', 'embeddings_cache']:
    os.makedirs(directory, exist_ok=True)
# Logging configuration
LOGGING_CONFIG = {
    'enabled': True,
    'functions': {
        'encode': True,
        'store_embeddings': True,
        'search': True,
        'load_and_process_csvs': True,
        'process_query': True
    }
}
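
# Example (illustrative): per-function tracing can be toggled at runtime, e.g.
#   LOGGING_CONFIG['functions']['encode'] = False   # silence just the encoder
#   LOGGING_CONFIG['enabled'] = False               # silence all traced functions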
def download_file_with_progress(url: str, filename: str):
    """Download a file with a progress bar using requests."""
    response = requests.get(url, stream=True, timeout=30)  # connect/read timeout
    response.raise_for_status()  # Fail fast on HTTP errors instead of writing an error page to disk
    total_size = int(response.headers.get('content-length', 0))

    with open(filename, 'wb') as file, tqdm(
        desc=filename,
        total=total_size,
        unit='iB',
        unit_scale=True,
        unit_divisor=1024,
    ) as progress_bar:
        for data in response.iter_content(chunk_size=1024):
            size = file.write(data)
            progress_bar.update(size)
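
# Usage sketch (illustrative; URL and filename are placeholders, not app assets):
#   download_file_with_progress(
#       "https://example.com/files/some-model.gguf",
#       "models/some-model.gguf",
#   )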
def log_function(func: Callable) -> Callable:
    """Decorator to log function inputs and outputs"""
    @functools.wraps(func)  # Preserve the wrapped function's name and docstring
    def wrapper(*args, **kwargs):
        if not LOGGING_CONFIG['enabled'] or not LOGGING_CONFIG['functions'].get(func.__name__, False):
            return func(*args, **kwargs)

        if args and hasattr(args[0], '__class__'):
            class_name = args[0].__class__.__name__
        else:
            class_name = func.__module__

        timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')
        log_args = args[1:] if class_name != func.__module__ else args

        def format_arg(arg):
            if isinstance(arg, torch.Tensor):
                return f"Tensor(shape={list(arg.shape)}, device={arg.device})"
            elif isinstance(arg, list):
                return f"List(len={len(arg)})"
            elif isinstance(arg, str) and len(arg) > 100:
                return f"String(len={len(arg)}): {arg[:100]}..."
            return arg

        formatted_args = [format_arg(arg) for arg in log_args]
        formatted_kwargs = {k: format_arg(v) for k, v in kwargs.items()}

        print(f"\n{'='*80}")
        print(f"[{timestamp}] FUNCTION CALL: {class_name}.{func.__name__}")
        print("INPUTS:")
        print(f"  args: {formatted_args}")
        print(f"  kwargs: {formatted_kwargs}")

        result = func(*args, **kwargs)
        formatted_result = format_arg(result)

        print("OUTPUT:")
        print(f"  {formatted_result}")
        print(f"{'='*80}\n")

        return result
    return wrapper
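
# Usage sketch: the decorator is opt-in via LOGGING_CONFIG['functions'].
# `fetch_scores` is a hypothetical example, not part of this app:
#   LOGGING_CONFIG['functions']['fetch_scores'] = True
#   @log_function
#   def fetch_scores(team: str) -> List[str]:
#       return [f"{team}: 2-1"]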
def check_environment():
    """Check if the environment is properly set up."""
    try:
        import numpy as np
        import torch
        import sentence_transformers
        import llama_cpp
        return True
    except ImportError as e:
        st.error(f"Missing required package: {str(e)}")
        st.stop()
        return False  # Unreachable: st.stop() raises, kept as a defensive fallback
class SentenceTransformerRetriever:
    def __init__(self, model_name: str = "sentence-transformers/all-MiniLM-L6-v2", cache_dir: str = "embeddings_cache"):
        self.device = torch.device("cpu")
        self.model_name = model_name
        self.cache_dir = cache_dir
        self.cache_file = "embeddings.pkl"
        self.doc_embeddings = None
        os.makedirs(cache_dir, exist_ok=True)
        self.model = self._load_model(model_name)
    @st.cache_resource  # Cache the model across reruns; the leading underscores keep Streamlit from hashing these arguments
    def _load_model(_self, _model_name: str):
        try:
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                model = SentenceTransformer(_model_name, device="cpu")
                # Smoke-test the model before returning it
                test_embedding = model.encode("test", convert_to_tensor=True)
                if not isinstance(test_embedding, torch.Tensor):
                    raise ValueError("Model initialization failed")
                return model
        except Exception as e:
            logging.error(f"Error loading model: {str(e)}")
            raise
    def get_cache_path(self, data_folder: str = None) -> str:
        # data_folder is accepted for API symmetry; a single shared cache file is used
        return os.path.join(self.cache_dir, self.cache_file)

    def save_cache(self, data_folder: str, cache_data: dict):
        try:
            cache_path = self.get_cache_path()
            if os.path.exists(cache_path):
                os.remove(cache_path)
            with open(cache_path, 'wb') as f:
                pickle.dump(cache_data, f)
            logging.info(f"Cache saved at: {cache_path}")
        except Exception as e:
            logging.error(f"Error saving cache: {str(e)}")
            raise

    def load_cache(self, data_folder: str = None) -> Optional[Dict]:
        try:
            cache_path = self.get_cache_path()
            if os.path.exists(cache_path):
                with open(cache_path, 'rb') as f:
                    logging.info(f"Loading cache from: {cache_path}")
                    cache_data = pickle.load(f)
                    if isinstance(cache_data, dict) and 'embeddings' in cache_data and 'documents' in cache_data:
                        return cache_data
                    logging.warning("Invalid cache format")
            return None
        except Exception as e:
            logging.error(f"Error loading cache: {str(e)}")
            return None
    def encode(self, texts: List[str], batch_size: int = 64) -> torch.Tensor:  # Increased batch size
        try:
            # Show a Streamlit progress bar while encoding in batches
            progress_text = "Processing documents..."
            progress_bar = st.progress(0, text=progress_text)

            all_embeddings = []
            for i in range(0, len(texts), batch_size):
                batch = texts[i:i + batch_size]
                batch_embeddings = self.model.encode(
                    batch,
                    convert_to_tensor=True,
                    show_progress_bar=False  # Disable tqdm; the Streamlit bar tracks progress instead
                )
                all_embeddings.append(batch_embeddings)

                # Update progress
                progress = min((i + batch_size) / len(texts), 1.0)
                progress_bar.progress(progress, text=progress_text)

            # Clear progress bar
            progress_bar.empty()

            # Concatenate all batches and L2-normalise so cosine similarity reduces to a dot product
            embeddings = torch.cat(all_embeddings, dim=0)
            return F.normalize(embeddings, p=2, dim=1)
        except Exception as e:
            logging.error(f"Error encoding texts: {str(e)}")
            raise
    def store_embeddings(self, embeddings: torch.Tensor):
        self.doc_embeddings = embeddings

    def search(self, query_embedding: torch.Tensor, k: int, documents: List[str]):
        try:
            if self.doc_embeddings is None:
                raise ValueError("No document embeddings stored!")

            similarities = F.cosine_similarity(query_embedding, self.doc_embeddings)
            k = min(k, len(documents))
            scores, indices = torch.topk(similarities, k=k)

            logging.info("\nSimilarity Stats:")
            logging.info(f"Max similarity: {similarities.max().item():.4f}")
            logging.info(f"Mean similarity: {similarities.mean().item():.4f}")
            logging.info(f"Selected similarities: {scores.tolist()}")

            return indices.cpu(), scores.cpu()
        except Exception as e:
            logging.error(f"Error in search: {str(e)}")
            raise
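
# End-to-end retrieval sketch (illustrative; `docs` is a stand-in corpus):
#   retriever = SentenceTransformerRetriever()
#   docs = ["Team A beat Team B 3-1", "Team C signed a new coach"]
#   retriever.store_embeddings(retriever.encode(docs))
#   indices, scores = retriever.search(retriever.encode(["Who won?"]), k=1, documents=docs)
#   print(docs[int(indices[0])], scores[0].item())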
class RAGPipeline:
    def __init__(self, data_folder: str, k: int = 5):
        self.data_folder = data_folder
        self.k = k
        self.retriever = SentenceTransformerRetriever()
        self.documents = []
        self.device = torch.device("cpu")

        # Change 1: Process documents first
        self.load_and_process_csvs()

        # Change 2: Simplified model path
        self.model_path = "mistral-7b-v0.1.Q4_K_M.gguf"
        self.llm = None

        # Change 3: Initialize model after documents are processed
        self._initialize_model()
    def _initialize_model(self):
        # Note: this method mutates self.llm, so it is not wrapped in a Streamlit
        # cache decorator; caching happens at the pipeline level instead.
        try:
            if not os.path.exists(self.model_path):
                direct_url = "https://huggingface.co/TheBloke/Mistral-7B-v0.1-GGUF/resolve/main/mistral-7b-v0.1.Q4_K_M.gguf"
                download_file_with_progress(direct_url, self.model_path)

            # Better error handling: verify the file actually landed on disk
            if not os.path.exists(self.model_path):
                raise FileNotFoundError(f"Model file {self.model_path} not found after download attempts")

            # Verbose mode for easier debugging
            llm_config = {
                "n_ctx": 2048,
                "n_threads": 4,
                "n_batch": 512,
                "n_gpu_layers": 0,
                "verbose": True
            }

            self.llm = Llama(model_path=self.model_path, **llm_config)
            st.success("Model loaded successfully!")
        except Exception as e:
            logging.error(f"Error initializing model: {str(e)}")
            st.error(f"Error initializing model: {str(e)}")
            raise
    def check_model_health(self):
        try:
            if self.llm is None:
                return False
            test_response = self.llm(
                "Test prompt",
                max_tokens=10,
                temperature=0.4,
                echo=False
            )
            return isinstance(test_response, dict) and 'choices' in test_response
        except Exception:
            return False
    def load_and_process_csvs(self):
        try:
            # Try loading from cache first
            cache_data = self.retriever.load_cache(self.data_folder)
            if cache_data is not None:
                self.documents = cache_data['documents']
                self.retriever.store_embeddings(cache_data['embeddings'])
                st.success("Loaded documents from cache")
                return

            st.info("Processing documents... This may take a while.")
            csv_files = glob.glob(os.path.join(self.data_folder, "*.csv"))
            if not csv_files:
                raise FileNotFoundError(f"No CSV files found in {self.data_folder}")

            all_documents = []
            total_files = len(csv_files)

            # Create a progress bar
            progress_bar = st.progress(0)

            for idx, csv_file in enumerate(csv_files):
                try:
                    df = pd.read_csv(csv_file, low_memory=False)
                    # Flatten each row into a single space-joined text document
                    texts = df.apply(lambda x: " ".join(x.astype(str)), axis=1).tolist()
                    all_documents.extend(texts)

                    # Update progress
                    progress = (idx + 1) / total_files
                    progress_bar.progress(progress)
                except Exception as e:
                    logging.error(f"Error processing file {csv_file}: {e}")
                    continue

            # Clear progress bar
            progress_bar.empty()

            if not all_documents:
                raise ValueError("No documents were successfully loaded")

            st.info(f"Processing {len(all_documents)} documents...")
            self.documents = all_documents
            embeddings = self.retriever.encode(all_documents)
            self.retriever.store_embeddings(embeddings)

            # Save to cache
            cache_data = {
                'embeddings': embeddings,
                'documents': self.documents
            }
            self.retriever.save_cache(self.data_folder, cache_data)
            st.success("Document processing complete!")
        except Exception as e:
            logging.error(f"Error in load_and_process_csvs: {str(e)}")
            raise
    def preprocess_query(self, query: str) -> str:
        """Normalise the query: lowercase and collapse whitespace."""
        query = query.lower().strip()
        query = re.sub(r'\s+', ' ', query)
        return query

    def postprocess_response(self, response: str) -> str:
        """Collapse whitespace and strip leaked timestamps from the answer."""
        response = re.sub(r'\s+', ' ', response)
        response = re.sub(r'\d{4}-\d{2}-\d{2}\s\d{2}:\d{2}:\d{2}(?:\+\d{2}:?\d{2})?', '', response)
        return response.strip()  # Strip last, so whitespace left behind by the removals is also trimmed
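
    # For example (illustrative), a leaked timestamp is removed and spacing re-trimmed:
    #   postprocess_response("Final score 2-1. 2024-05-01 18:30:00")  ->  "Final score 2-1."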
    def process_query(self, query: str, placeholder) -> str:
        try:
            if self.llm is None:
                raise RuntimeError("LLM model not initialized")
            if self.retriever.model is None:
                raise RuntimeError("Sentence transformer model not initialized")

            query = self.preprocess_query(query)
            status = placeholder.empty()
            status.write("🔍 Finding relevant information...")

            query_embedding = self.retriever.encode([query])
            indices, scores = self.retriever.search(query_embedding, self.k, self.documents)

            logging.info("\nSearch Results:")
            for idx, score in zip(indices.tolist(), scores.tolist()):
                logging.info(f"Score: {score:.4f} | Document: {self.documents[idx][:100]}...")

            relevant_docs = [self.documents[idx] for idx in indices.tolist()]

            status.write("💭 Generating response...")
            context = "\n".join(relevant_docs)
            prompt = f"""Context information is below:
{context}

Given the context above, please answer the following question:
{query}

Guidelines:
- If you cannot answer based on the context, say so politely
- Keep the response concise and focused
- Only include sports-related information
- No dates or timestamps in the response
- Use clear, natural language

Answer:"""

            response_placeholder = placeholder.empty()
            try:
                response = self.llm(
                    prompt,
                    max_tokens=512,
                    temperature=0.4,
                    top_p=0.95,
                    echo=False,
                    stop=["Question:", "\n\n"]
                )

                if response and 'choices' in response and len(response['choices']) > 0:
                    generated_text = response['choices'][0].get('text', '').strip()
                    if generated_text:
                        final_response = self.postprocess_response(generated_text)
                        response_placeholder.markdown(final_response)
                        return final_response
                    else:
                        message = "No relevant answer found. Please try rephrasing your question."
                        response_placeholder.warning(message)
                        return message
                else:
                    message = "Unable to generate response. Please try again."
                    response_placeholder.warning(message)
                    return message
            except Exception as e:
                logging.error(f"Generation error: {str(e)}")
                message = "Had some trouble generating the response. Please try again."
                response_placeholder.warning(message)
                return message
        except Exception as e:
            logging.error(f"Process error: {str(e)}")
            message = "Something went wrong. Please try again with a different question."
            placeholder.warning(message)
            return message
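
# Sketch (illustrative) of driving the pipeline from a Streamlit script; note that
# constructing RAGPipeline triggers document processing and the model download:
#   rag = RAGPipeline("ESPN_data")
#   answer = rag.process_query("Who won the last match?", st.empty())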
@st.cache_resource  # Cache the pipeline across reruns so documents and the LLM load only once
def initialize_rag_pipeline():
    """Initialize the RAG pipeline once"""
    try:
        data_folder = "ESPN_data"
        if not os.path.exists(data_folder):
            os.makedirs(data_folder, exist_ok=True)

        # Check for cache
        cache_path = os.path.join("embeddings_cache", "embeddings.pkl")
        if os.path.exists(cache_path):
            st.info("Found cached data. Loading...")
        else:
            st.warning("Initial setup may take several minutes...")

        rag = RAGPipeline(data_folder)
        return rag
    except Exception as e:
        logging.error(f"Pipeline initialization error: {str(e)}")
        st.error("Failed to initialize the system. Please check if all required files are present.")
        raise
def main():
    try:
        # Page config (st.set_page_config must be the first Streamlit call,
        # so it runs before the environment check below)
        st.set_page_config(
            page_title="The Sport Chatbot",
            page_icon="🏆",
            layout="wide"
        )

        # Environment check
        if not check_environment():
            return

        # Improved CSS styling
        st.markdown("""
            <style>
            .block-container {
                padding-top: 2rem;
                padding-bottom: 2rem;
            }
            .stTextInput > div > div > input {
                width: 100%;
            }
            .stButton > button {
                width: 200px;
                margin: 0 auto;
                display: block;
                background-color: #FF4B4B;
                color: white;
                border-radius: 5px;
                padding: 0.5rem 1rem;
            }
            .main-title {
                text-align: center;
                padding: 1rem 0;
                font-size: 3rem;
                color: #1F1F1F;
            }
            .sub-title {
                text-align: center;
                padding: 0.5rem 0;
                font-size: 1.5rem;
                color: #4F4F4F;
            }
            .description {
                text-align: center;
                color: #666666;
                padding: 0.5rem 0;
                font-size: 1.1rem;
                line-height: 1.6;
                margin-bottom: 1rem;
            }
            .stMarkdown {
                max-width: 100%;
            }
            .st-emotion-cache-16idsys p {
                font-size: 1.1rem;
                line-height: 1.6;
            }
            .main-content {
                max-width: 1200px;
                margin: 0 auto;
                padding: 0 1rem;
            }
            </style>
        """, unsafe_allow_html=True)

        # Header section
        st.markdown("<h1 class='main-title'>🏆 The Sport Chatbot</h1>", unsafe_allow_html=True)
        st.markdown("<h3 class='sub-title'>Using the ESPN API</h3>", unsafe_allow_html=True)
        st.markdown("""
            <p class='description'>
                Hey there! 👋 I can help you with information on Ice Hockey, Baseball, American Football, Soccer, and Basketball.
                With access to the ESPN API, I'm up to date with the latest details for these sports up until October 2024.
            </p>
            <p class='description'>
                Got any general questions? Feel free to ask, and I'll do my best to provide answers based on the information I've been trained on!
            </p>
        """, unsafe_allow_html=True)
        # Add spacing
        st.markdown("<br>", unsafe_allow_html=True)

        # Initialize the pipeline
        try:
            with st.spinner("Loading resources..."):
                rag = initialize_rag_pipeline()

            # Model health check before accepting queries
            if not rag.check_model_health():
                st.error("Model initialization failed. Please try restarting the application.")
                return
        except Exception as e:
            logging.error(f"Initialization error: {str(e)}")
            st.error("Unable to initialize the system. Please check if all required files are present.")
            return

        # Create columns to center the main content in the wide middle column
        col1, col2, col3 = st.columns([1, 6, 1])

        with col2:
            # Query input
            query = st.text_input("What would you like to know about sports?")

            # Centered button
            if st.button("Get Answer"):
                if query:
                    response_placeholder = st.empty()
                    try:
                        response = rag.process_query(query, response_placeholder)
                        logging.info(f"Generated response: {response}")
                    except Exception as e:
                        logging.error(f"Query processing error: {str(e)}")
                        response_placeholder.warning("Unable to process your question. Please try again.")
                else:
                    st.warning("Please enter a question!")

        # Footer
        st.markdown("<br><br>", unsafe_allow_html=True)
        st.markdown("---")
        st.markdown("""
            <p style='text-align: center; color: #666666; padding: 1rem 0;'>
                Powered by ESPN Data & Mistral AI 🏆
            </p>
        """, unsafe_allow_html=True)
    except Exception as e:
        logging.error(f"Application error: {str(e)}")
        st.error("An unexpected error occurred. Please check the logs and try again.")
if __name__ == "__main__":
    try:
        main()
    except Exception as e:
        logging.error(f"Application error: {str(e)}")
        st.error("An unexpected error occurred. Please check the logs and try again.")