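"""
Streamlit app for a sports RAG chatbot: retrieves ESPN-derived passages with a
sentence-transformer retriever and answers questions with a locally run
Mistral-7B GGUF model via llama-cpp-python. Precomputed document embeddings are
loaded from Google Drive at startup.
"""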
import os
import warnings
warnings.filterwarnings("ignore", category=UserWarning)
import streamlit as st
import torch
import torch.nn.functional as F
import re
import requests
from embedding_processor import SentenceTransformerRetriever, process_data
import pickle
import logging
import sys
from llama_cpp import Llama
from tqdm import tqdm
# llama.cpp runtime settings; set before the model is created
os.environ['LLAMA_CPP_THREADS'] = '4'
os.environ['LLAMA_CPP_BATCH_SIZE'] = '512'
os.environ['LLAMA_CPP_MODEL_PATH'] = os.path.join("models", "mistral-7b-v0.1.Q4_K_M.gguf")
# Set page config first
st.set_page_config(
    page_title="The Sport Chatbot",
    page_icon="π",
    layout="wide"
)
# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler(sys.stdout)]
)
# Cached at module level so the model is loaded only once per session
# (the pipeline below relies on this being the cached handle).
@st.cache_resource(show_spinner=False)
def get_llama_model():
    model_path = os.path.join("models", "mistral-7b-v0.1.Q4_K_M.gguf")
    os.makedirs(os.path.dirname(model_path), exist_ok=True)

    if not os.path.exists(model_path):
        st.info("Downloading model... This may take a while.")
        direct_url = "https://huggingface.co/TheBloke/Mistral-7B-v0.1-GGUF/resolve/main/mistral-7b-v0.1.Q4_K_M.gguf"
        download_file_with_progress(direct_url, model_path)

    llm_config = {
        "model_path": model_path,
        "n_ctx": 2048,
        "n_threads": 4,
        "n_batch": 512,
        "n_gpu_layers": 0,
        "verbose": False,
        "use_mlock": True
    }
    return Llama(**llm_config)
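# Example usage (sketch, not part of the app flow): the handle returned by
# get_llama_model() can be called directly, mirroring how query_model() uses it below.
#   llm = get_llama_model()
#   out = llm("Short test prompt", max_tokens=16)
#   text = out["choices"][0]["text"]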
def download_file_with_progress(url: str, filename: str):
    """Download a file with a progress bar using requests and tqdm."""
    response = requests.get(url, stream=True)
    response.raise_for_status()  # Fail early on HTTP errors instead of writing an error page to disk
    total_size = int(response.headers.get('content-length', 0))

    with open(filename, 'wb') as file, tqdm(
        desc=filename,
        total=total_size,
        unit='iB',
        unit_scale=True,
        unit_divisor=1024,
    ) as progress_bar:
        for data in response.iter_content(chunk_size=1024):
            size = file.write(data)
            progress_bar.update(size)
def load_from_drive(file_id: str):
    """Load a pickle file directly from Google Drive."""
    try:
        url = f"https://drive.google.com/uc?id={file_id}&export=download"
        session = requests.Session()
        response = session.get(url, stream=True)

        # For large files Drive returns a confirmation page; re-request the URL
        # with the confirm token taken from the download-warning cookie.
        for key, value in response.cookies.items():
            if key.startswith('download_warning'):
                url = f"{url}&confirm={value}"
                response = session.get(url, stream=True)
                break

        content = response.content
        print(f"Successfully downloaded {len(content)} bytes")
        return pickle.loads(content)

    except Exception as e:
        print(f"Detailed error: {str(e)}")
        st.error(f"Error loading file from Drive: {str(e)}")
        return None
def load_llama_model():
    """Load and verify the Llama model, downloading it first if missing.

    Alternative loader kept alongside get_llama_model(); not referenced
    elsewhere in this file.
    """
    try:
        model_path = "mistral-7b-v0.1.Q4_K_M.gguf"

        if not os.path.exists(model_path):
            st.info("Downloading model... This may take a while.")
            direct_url = "https://huggingface.co/TheBloke/Mistral-7B-v0.1-GGUF/resolve/main/mistral-7b-v0.1.Q4_K_M.gguf"
            download_file_with_progress(direct_url, model_path)

        if not os.path.exists(model_path):
            raise FileNotFoundError("Model file not found after download attempt")

        if os.path.getsize(model_path) < 1000000:  # Less than 1MB
            raise ValueError("Model file is too small, likely corrupted")

        llm_config = {
            "model_path": model_path,
            "n_ctx": 2048,
            "n_threads": 4,
            "n_batch": 512,
            "n_gpu_layers": 0,
            "verbose": True  # Enable verbose mode for debugging
        }

        logging.info("Initializing Llama model...")
        model = Llama(**llm_config)

        # Smoke-test the model with a tiny completion
        logging.info("Testing model...")
        test_response = model("Test", max_tokens=10)
        if not test_response:
            raise RuntimeError("Model test failed")

        logging.info("Model loaded and tested successfully")
        st.success("Model loaded successfully!")
        return model

    except Exception as e:
        logging.error(f"Error loading model: {str(e)}")
        logging.error("Full error details: ", exc_info=True)
        raise
def check_environment():
    """Check if the environment is properly set up"""
    try:
        import torch
        import sentence_transformers
        return True
    except ImportError as e:
        st.error(f"Missing required package: {str(e)}")
        st.stop()
        return False
class RAGPipeline:
    def __init__(self, data_folder: str, k: int = 5):
        self.data_folder = data_folder
        self.k = k
        self.retriever = SentenceTransformerRetriever()
        self.documents = []
        self.device = torch.device("cpu")
        # Needed by initialize_model(); matches the path used by get_llama_model()
        self.model_path = os.path.join("models", "mistral-7b-v0.1.Q4_K_M.gguf")
        # Use the cached model directly
        self.llm = get_llama_model()

    def preprocess_query(self, query: str) -> str:
        """Clean and prepare the query"""
        query = query.lower().strip()
        query = re.sub(r'\s+', ' ', query)
        return query
    ### Added on Nov 2, 2024
    def postprocess_response(self, response: str) -> str:
        """Clean up the generated response"""
        try:
            # Remove datetime patterns and other unwanted content
            response = re.sub(r'\d{4}-\d{2}-\d{2}(?:T|\s)\d{2}:\d{2}:\d{2}(?:\.\d+)?(?:Z|[+-]\d{2}:?\d{2})?', '', response)
            response = re.sub(r'User \d+:.*?(?=User \d+:|$)', '', response)
            response = re.sub(r'\d{2}:\d{2}(?::\d{2})?(?:\s?(?:AM|PM))?', '', response)
            response = re.sub(r'\d{1,2}[-/]\d{1,2}[-/]\d{2,4}', '', response)
            response = re.sub(r'(?m)^User \d+:', '', response)

            # Clean up spacing but preserve intentional paragraph breaks:
            # collapse runs of blank lines into a single paragraph break
            response = re.sub(r'\n\s*\n\s*\n+', '\n\n', response)
            # Replace multiple spaces with a single space
            response = re.sub(r' +', ' ', response)

            # Clean up beginning/end
            response = response.strip()
            return response

        except Exception as e:
            logging.error(f"Error in postprocess_response: {str(e)}")
            return response
    def process_query(self, query: str, placeholder) -> str:
        try:
            # Verify this is the current query being processed
            if hasattr(st.session_state, 'current_query') and query != st.session_state.current_query:
                logging.warning(f"Skipping outdated query: {query}")
                return ""

            query = self.preprocess_query(query)
            status = placeholder.empty()
            status.write("π Finding relevant information...")

            # Embed the query and rank documents by cosine similarity
            query_embedding = self.retriever.encode([query])
            similarities = F.cosine_similarity(query_embedding, self.retriever.doc_embeddings)
            scores, indices = torch.topk(similarities, k=min(self.k, len(self.documents)))
            relevant_docs = [self.documents[idx] for idx in indices.tolist()]

            # Strip dates, timestamps, and user tags from the retrieved passages before prompting
            cleaned_docs = []
            for doc in relevant_docs[:3]:
                cleaned_text = self.postprocess_response(doc)
                if cleaned_text:
                    cleaned_docs.append(cleaned_text)

            status.write("π Generating response...")

            prompt = f"""Context information is below:
{' '.join(cleaned_docs)}

Given the context above, please answer the following question:
{query}

Guidelines for your response:
- Structure your response in clear, logical paragraphs
- Start a new paragraph for each new main point or aspect
- If listing multiple items, use separate paragraphs
- Keep each paragraph focused on a single topic or point
- Use natural paragraph breaks where the content shifts focus
- Maintain clear transitions between paragraphs
- If providing statistics or achievements, group them logically
- If describing different aspects (e.g., career, playing style, achievements), use separate paragraphs
- Keep paragraphs concise but complete
- Exclude any dates, timestamps, or user comments
- Focus on factual sports information
- If you cannot answer based on the context, say so politely

Format your response with proper paragraph breaks where appropriate.
Answer:"""

            response_placeholder = placeholder.empty()

            try:
                response_text = self.query_model(prompt)
                if response_text:
                    # Clean up the response while preserving paragraph structure
                    final_response = self.postprocess_response(response_text)
                    # Convert cleaned response to markdown with proper paragraph spacing
                    markdown_response = final_response.replace('\n\n', '\n\n \n\n')
                    response_placeholder.markdown(markdown_response)
                    return final_response
                else:
                    message = "No relevant answer found. Please try rephrasing your question."
                    response_placeholder.warning(message)
                    return message

            except Exception as e:
                logging.error(f"Generation error: {str(e)}")
                message = "Had some trouble generating the response. Please try again."
                response_placeholder.warning(message)
                return message

        except Exception as e:
            logging.error(f"Process error: {str(e)}")
            message = "Something went wrong. Please try again with a different question."
            placeholder.warning(message)
            return message
    def query_model(self, prompt: str) -> str:
        """Query the local Llama model"""
        try:
            if self.llm is None:
                raise RuntimeError("Model not initialized")

            logging.info("Sending prompt to model...")

            # Generate a completion with explicit sampling parameters
            response = self.llm(
                prompt,
                max_tokens=512,        # Maximum length of the response
                temperature=0.7,       # Slightly increased for more dynamic responses
                top_p=0.95,            # Nucleus sampling parameter
                top_k=50,              # Top-k sampling parameter
                echo=False,            # Don't include the prompt in the response
                stop=["Question:", "Context:", "Guidelines:"],  # Stop tokens
                repeat_penalty=1.1,    # Penalize repetition
                presence_penalty=0.5,  # Encourage topic diversity
                frequency_penalty=0.5  # Discourage word repetition
            )

            # Log the raw response for debugging
            logging.info(f"Raw model response: {response}")

            if response and isinstance(response, dict) and 'choices' in response and response['choices']:
                generated_text = response['choices'][0].get('text', '').strip()
                if generated_text:
                    logging.info(f"Generated text: {generated_text[:100]}...")  # Log first 100 chars
                    return generated_text
                else:
                    logging.warning("Model returned empty response")
                    raise ValueError("Empty response from model")
            else:
                logging.warning(f"Unexpected response format: {response}")
                raise ValueError("Invalid response format from model")

        except Exception as e:
            logging.error(f"Error in query_model: {str(e)}")
            logging.error("Full error details: ", exc_info=True)
            raise
    def initialize_model(self):
        """Initialize the model with proper error handling and verification.

        Alternative initializer; __init__ currently uses get_llama_model() instead.
        """
        try:
            if not os.path.exists(self.model_path):
                st.info("Downloading model... This may take a while.")
                direct_url = "https://huggingface.co/TheBloke/Mistral-7B-v0.1-GGUF/resolve/main/mistral-7b-v0.1.Q4_K_M.gguf"
                download_file_with_progress(direct_url, self.model_path)

            # Verify the file exists and has plausible content
            if not os.path.exists(self.model_path):
                raise FileNotFoundError(f"Model file {self.model_path} not found after download attempts")

            if os.path.getsize(self.model_path) < 1000000:  # Less than 1MB
                os.remove(self.model_path)
                raise ValueError("Downloaded model file is too small, likely corrupted")

            # Updated model configuration
            llm_config = {
                "model_path": self.model_path,
                "n_ctx": 4096,             # Increased context window
                "n_threads": 4,
                "n_batch": 512,
                "n_gpu_layers": 0,
                "verbose": True,           # Enable verbose mode for debugging
                "use_mlock": False,        # Disable memory locking
                "last_n_tokens_size": 64,  # Token window size for repeat penalty
                "seed": -1                 # -1 = a new random seed each run
            }

            logging.info("Initializing Llama model...")
            self.llm = Llama(**llm_config)

            # Test the model
            test_response = self.llm(
                "Test response",
                max_tokens=10,
                temperature=0.7,
                echo=False
            )

            if not test_response or 'choices' not in test_response:
                raise RuntimeError("Model initialization test failed")

            logging.info("Model initialized and tested successfully")
            return self.llm

        except Exception as e:
            logging.error(f"Error initializing model: {str(e)}")
            raise
def initialize_rag_pipeline():
    """Initialize the RAG pipeline once"""
    try:
        data_folder = "ESPN_data"
        if not os.path.exists(data_folder):
            os.makedirs(data_folder, exist_ok=True)

        # Load embeddings first
        drive_file_id = "1MuV63AE9o6zR9aBvdSDQOUextp71r2NN"
        with st.spinner("Loading data..."):
            cache_data = load_from_drive(drive_file_id)
            if cache_data is None:
                st.error("Failed to load embeddings from Google Drive")
                st.stop()

        # Initialize pipeline
        rag = RAGPipeline(data_folder)

        # Store embeddings
        rag.documents = cache_data['documents']
        rag.retriever.store_embeddings(cache_data['embeddings'])

        return rag

    except Exception as e:
        logging.error(f"Pipeline initialization error: {str(e)}")
        st.error(f"Failed to initialize the system: {str(e)}")
        raise
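# Note (assumption, inferred from the keys used above): the pickle downloaded from
# Drive is expected to hold something like
#   {"documents": [<passage text>, ...], "embeddings": <tensor accepted by store_embeddings()>}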
def main():
    try:
        # First, check if the model exists
        model_path = os.path.join("models", "mistral-7b-v0.1.Q4_K_M.gguf")
        if not os.path.exists(model_path):
            st.warning("⚠️ First-time setup: The model will be downloaded. This takes a few minutes but only happens once.")

        # Environment check
        if not check_environment():
            return

        # Initialize session state variables
        if 'current_query' not in st.session_state:
            st.session_state.current_query = None
        if 'processing' not in st.session_state:
            st.session_state.processing = False
        # Improved CSS styling
        st.markdown("""
            <style>
            /* Container styling */
            .block-container {
                padding-top: 2rem;
                padding-bottom: 2rem;
            }
            /* Text input styling */
            .stTextInput > div > div > input {
                width: 100%;
            }
            /* Button styling */
            .stButton > button {
                width: 200px;
                margin: 0 auto;
                display: block;
                background-color: #FF4B4B;
                color: white;
                border-radius: 5px;
                padding: 0.5rem 1rem;
            }
            /* Title styling */
            .main-title {
                text-align: center;
                padding: 1rem 0;
                font-size: 3rem;
                color: #1F1F1F;
            }
            .sub-title {
                text-align: center;
                padding: 0.5rem 0;
                font-size: 1.5rem;
                color: #4F4F4F;
            }
            /* Description styling */
            .description {
                text-align: center;
                color: #666666;
                padding: 0.5rem 0;
                font-size: 1.1rem;
                line-height: 1.6;
                margin-bottom: 1rem;
            }
            /* Answer container styling */
            .stMarkdown {
                max-width: 100%;
            }
            /* Streamlit default overrides */
            .st-emotion-cache-16idsys p {
                font-size: 1.1rem;
                line-height: 1.6;
            }
            /* Container for main content */
            .main-content {
                max-width: 1200px;
                margin: 0 auto;
                padding: 0 1rem;
            }
            </style>
        """, unsafe_allow_html=True)
        # Header section
        st.markdown("<h1 class='main-title'>π The Sport Chatbot</h1>", unsafe_allow_html=True)
        st.markdown("<h3 class='sub-title'>Using ESPN API</h3>", unsafe_allow_html=True)
        st.markdown("""
            <p class='description'>
            Hey there! π I can help you with information on Ice Hockey, Baseball, American Football, Soccer, and Basketball.
            With access to the ESPN API, I'm up to date with the latest details for these sports up until October 2024.
            </p>
            <p class='description'>
            Got any general questions? Feel free to ask, and I'll do my best to provide answers based on the information I've been trained on!
            </p>
        """, unsafe_allow_html=True)
        # Initialize the pipeline
        if 'rag' not in st.session_state:
            try:
                with st.spinner("Loading resources..."):
                    st.session_state.rag = initialize_rag_pipeline()
                    logging.info("Pipeline initialized successfully")
            except Exception as e:
                logging.error(f"Pipeline initialization error: {str(e)}")
                st.error("Failed to initialize the system. Please check the logs.")
                st.stop()
                return

        # Create columns for layout
        col1, col2, col3 = st.columns([1, 6, 1])

        with col2:
            # Query input with unique key
            query = st.text_input(
                "What would you like to know about sports?",
                key="sports_query"
            )

            # Centered button with unique key
            if st.button("Get Answer", key="answer_button"):
                if query:
                    # Clear any previous response
                    if 'response_placeholder' in st.session_state:
                        st.session_state.response_placeholder.empty()

                    response_placeholder = st.empty()
                    st.session_state.response_placeholder = response_placeholder

                    try:
                        # Update current query and processing state
                        st.session_state.current_query = query
                        st.session_state.processing = True

                        # Log query processing start
                        logging.info(f"Processing query: {query}")

                        with st.spinner("Processing your question..."):
                            # Process query and get response
                            response = st.session_state.rag.process_query(query, response_placeholder)
                            # Log successful response
                            logging.info(f"Generated response: {response}")

                        # Reset processing state
                        st.session_state.processing = False

                    except Exception as e:
                        # Log error details
                        logging.error(f"Query processing error: {str(e)}")
                        logging.error("Full error details: ", exc_info=True)
                        response_placeholder.warning("Unable to process your question. Please try again.")
                        st.session_state.processing = False
                else:
                    st.warning("Please enter a question!")

        # Footer
        st.markdown("<br><br>", unsafe_allow_html=True)
        st.markdown("---")
        st.markdown("""
            <p style='text-align: center; color: #666666; padding: 1rem 0;'>
            Powered by ESPN Data & Mistral AI π
            </p>
        """, unsafe_allow_html=True)

    except Exception as e:
        logging.error(f"Application error: {str(e)}")
        logging.error("Full error details: ", exc_info=True)
        st.error("An unexpected error occurred. Please check the logs and try again.")
if __name__ == "__main__": | |
# Configure logging | |
logging.basicConfig( | |
level=logging.INFO, | |
format='%(asctime)s - %(levelname)s - %(message)s' | |
) | |
try: | |
main() | |
except Exception as e: | |
logging.error(f"Fatal error: {str(e)}") | |
logging.error("Full error details: ", exc_info=True) | |
st.error("A fatal error occurred. Please check the logs and try again.") |