Spaces:
Sleeping
Sleeping
import os | |
import streamlit as st | |
import torch | |
from typing import List, Dict, Any | |
from langchain_core.prompts import ChatPromptTemplate | |
from langchain_groq import ChatGroq | |
from langchain_community.vectorstores import Chroma | |
from langchain_community.embeddings import HuggingFaceBgeEmbeddings | |
from langchain.text_splitter import RecursiveCharacterTextSplitter | |
from langchain.chains import RetrievalQA | |
from langchain.memory import ConversationBufferMemory | |
from langchain.chains import ConversationalRetrievalChain | |
from transformers import pipeline | |
from sentence_transformers import SentenceTransformer | |
import numpy as np | |
class AdvancedRAGChatbot:
    """Retrieval-augmented chatbot that also runs auxiliary NLP analyses per query.

    The constructor wires a Chroma vector store, an MMR retriever, a Groq chat
    model and a conversation buffer memory into a
    ``ConversationalRetrievalChain``.  Each query is additionally embedded
    with a sentence-transformer model and passed through sentiment-analysis
    and named-entity-recognition pipelines.
    """

    def __init__(self,
                 embedding_model: str = "BAAI/bge-large-en-v1.5",
                 llm_model: str = "llama-3.3-70b-versatile",
                 temperature: float = 0.7,
                 retrieval_k: int = 5):
        """Initialize the Advanced RAG Chatbot with configurable parameters.

        Args:
            embedding_model: Hugging Face model name used for the vector
                store embeddings.
            llm_model: Groq model name used to generate answers.
            temperature: Sampling temperature passed to the LLM.
            retrieval_k: Number of documents the retriever returns per query.
        """
        # Order matters: the vector store needs the embeddings, the retriever
        # needs the vector store, and the chain needs LLM, retriever and memory.
        self.embeddings = self._configure_embeddings(embedding_model)
        self.semantic_model = SentenceTransformer('all-MiniLM-L6-v2')
        self.sentiment_analyzer = pipeline("sentiment-analysis")
        self.ner_pipeline = pipeline("ner", aggregation_strategy="simple")
        self.llm = self._configure_llm(llm_model, temperature)
        self.vector_db = self._initialize_vector_database()
        self.retriever = self._configure_retriever(retrieval_k)
        # The chain is created with return_source_documents=True and therefore
        # returns two keys ('answer' and 'source_documents').  The memory must
        # be told which one to record; without output_key it cannot pick one
        # and saving the turn fails.
        self.memory = ConversationBufferMemory(
            memory_key="chat_history",
            return_messages=True,
            output_key="answer",
        )
        self.qa_chain = self._create_conversational_retrieval_chain()

    def _configure_embeddings(self, model_name: str):
        """Build the BGE embedding wrapper with normalized output vectors."""
        encode_kwargs = {'normalize_embeddings': True, 'show_progress_bar': True}
        return HuggingFaceBgeEmbeddings(model_name=model_name, encode_kwargs=encode_kwargs)

    def _configure_llm(self, model_name: str, temperature: float):
        """Build the Groq chat model used for answer generation."""
        # NOTE(review): no credentials are passed here, so they are presumably
        # picked up from the environment -- confirm for the deployment target.
        return ChatGroq(
            model_name=model_name,
            temperature=temperature,
            max_tokens=4096,
            streaming=True,
        )

    def _initialize_vector_database(self, persist_directory: str = 'vector_db'):
        """Open the Chroma vector store persisted under ``persist_directory``."""
        # NOTE(review): the same directory is used whatever embedding model is
        # selected -- confirm the stored collection matches that model.
        return Chroma(persist_directory=persist_directory, embedding_function=self.embeddings)

    def _configure_retriever(self, retrieval_k: int):
        """Build an MMR retriever that returns ``retrieval_k`` documents.

        ``search_type`` is an argument of ``as_retriever`` itself; only the
        per-search tuning values (``k``, ``fetch_k``) belong in
        ``search_kwargs``.
        """
        return self.vector_db.as_retriever(
            search_type="mmr",
            search_kwargs={
                "k": retrieval_k,
                # MMR re-ranks a candidate pool of fetch_k documents, so the
                # pool must never be smaller than the number requested.
                "fetch_k": max(20, retrieval_k),
            },
        )

    def _create_conversational_retrieval_chain(self):
        """Assemble the conversational retrieval chain from LLM, retriever and memory."""
        template = """
        You are a helpful AI assistant. Provide a precise and comprehensive answer
        based on the context and chat history.
        Context: {context}
        Chat History: {chat_history}
        Question: {question}
        Helpful Answer:"""
        prompt = ChatPromptTemplate.from_template(template)
        return ConversationalRetrievalChain.from_llm(
            llm=self.llm,
            retriever=self.retriever,
            memory=self.memory,
            combine_docs_chain_kwargs={'prompt': prompt},
            return_source_documents=True,
        )

    def process_query(self, query: str) -> Dict[str, Any]:
        """Answer ``query`` through the RAG chain and attach NLP analyses.

        Args:
            query: The user's question.

        Returns:
            A dict with the keys:
              * ``response`` -- the chain's answer text.
              * ``source_documents`` -- documents returned by the chain
                (empty list when the chain returns none).
              * ``semantic_similarity`` -- despite the name, the query's
                sentence-transformer embedding as a plain list.
              * ``sentiment`` -- first result of the sentiment pipeline.
              * ``named_entities`` -- output of the NER pipeline.
        """
        # Auxiliary NLP analysis of the raw query text.
        query_embedding = self.semantic_model.encode([query])[0]
        sentiment_result = self.sentiment_analyzer(query)[0]
        entities = self.ner_pipeline(query)

        # Retrieval-augmented answer generation.
        result = self.qa_chain({"question": query})

        return {
            "response": result['answer'],
            "source_documents": result.get('source_documents', []),
            "semantic_similarity": query_embedding.tolist(),
            "sentiment": sentiment_result,
            "named_entities": entities,
        }
def _get_chatbot(embedding_model: str, temperature: float, retrieval_k: int):
    """Return this session's chatbot, rebuilding it only when the settings change."""
    settings = (embedding_model, temperature, retrieval_k)
    state = st.session_state
    if "chatbot" not in state or state.get("chatbot_settings") != settings:
        state["chatbot"] = AdvancedRAGChatbot(
            embedding_model=embedding_model,
            temperature=temperature,
            retrieval_k=retrieval_k,
        )
        state["chatbot_settings"] = settings
    return state["chatbot"]


def _render_response(response):
    """Render the answer, sentiment, entities and source documents of one query."""
    # Bot Response
    st.markdown("#### Bot's Answer")
    st.write(response['response'])

    # Sentiment Analysis
    st.markdown("#### Sentiment Analysis")
    sentiment = response['sentiment']
    st.metric(
        label="Sentiment",
        value=sentiment['label'],
        delta=f"{sentiment['score']:.2%}",
    )

    # Named Entities
    st.markdown("#### Detected Entities")
    for entity in response['named_entities']:
        # Aggregated NER results carry the label under 'entity_group';
        # un-aggregated results use 'entity'.  Accept either.
        label = entity.get('entity_group') or entity.get('entity')
        st.text(f"{entity['word']} ({label})")

    # Source Documents
    if response['source_documents']:
        st.markdown("#### Reference Documents")
        for i, doc in enumerate(response['source_documents'], 1):
            with st.expander(f"Document {i}"):
                st.write(doc.page_content)


def main():
    """Run the Streamlit UI: sidebar settings, query input and response panel."""
    # Page Configuration
    st.set_page_config(
        page_title="Advanced RAG Chatbot",
        page_icon="🧠",
        layout="wide",
        initial_sidebar_state="expanded",
    )

    # Sidebar Configuration
    with st.sidebar:
        st.header("🔧 Chatbot Settings")
        st.markdown("Customize your AI assistant's behavior")
        # Model Configuration
        embedding_model = st.selectbox(
            "Embedding Model",
            ["BAAI/bge-large-en-v1.5", "sentence-transformers/all-MiniLM-L6-v2"],
        )
        temperature = st.slider("Creativity Level", 0.0, 1.0, 0.7, help="Higher values make responses more creative")
        retrieval_k = st.slider("Context Depth", 1, 10, 5, help="Number of reference documents to retrieve")
        # Additional Controls
        st.divider()
        reset_chat = st.button("🔄 Reset Conversation")

    # Streamlit re-executes this script on every interaction.  Keeping the
    # chatbot in session state avoids reloading the models on each rerun and
    # lets the conversation memory survive between queries.
    chatbot = _get_chatbot(embedding_model, temperature, retrieval_k)
    if reset_chat:
        chatbot.memory.clear()
        st.sidebar.success("Conversation history cleared.")

    # Main Chat Interface
    st.title("🤖 Advanced RAG Chatbot")

    # Two-column layout
    col1, col2 = st.columns(2)
    with col1:
        st.header("Input")
        # Chat input with placeholder
        user_input = st.text_area(
            "Ask your question",
            placeholder="Enter your query here...",
            height=250,
        )
        # Submit button
        submit_button = st.button("Send Query", type="primary")

    with col2:
        st.header("Response")
        # Response container
        if submit_button and user_input:
            with st.spinner("Processing your query..."):
                # UI boundary: surface any failure to the user instead of
                # letting the script crash.
                try:
                    response = chatbot.process_query(user_input)
                    _render_response(response)
                except Exception as e:
                    st.error(f"An error occurred: {e}")
        else:
            st.info("Submit a query to see the AI's response")
# Script entry point: launch the app only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()