# NewsInferno / app.py
# HEHEBOIOG's picture
# Update app.py
# 187e418 verified
# raw
# history blame
# 7.36 kB
import os
import streamlit as st
import torch
from typing import List, Dict, Any
from langchain_core.prompts import ChatPromptTemplate
from langchain_groq import ChatGroq
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains import RetrievalQA
from langchain.memory import ConversationBufferMemory
from langchain.chains import ConversationalRetrievalChain
from transformers import pipeline
from sentence_transformers import SentenceTransformer
import numpy as np
class AdvancedRAGChatbot:
    """Conversational retrieval-augmented chatbot with per-query NLP analysis.

    Each instance wires together an embedding model, a Chroma vector store,
    a Groq-hosted chat model, a conversation memory and a
    ``ConversationalRetrievalChain``. Alongside the generated answer,
    ``process_query`` reports the query embedding, a sentiment
    classification and the named entities found in the query.
    """

    def __init__(self,
                 embedding_model: str = "BAAI/bge-large-en-v1.5",
                 llm_model: str = "llama-3.3-70b-versatile",
                 temperature: float = 0.7,
                 retrieval_k: int = 5):
        """Initialize the Advanced RAG Chatbot with configurable parameters.

        Args:
            embedding_model: Model name handed to the BGE embedding wrapper
                that backs the vector store.
            llm_model: Model name handed to ``ChatGroq``.
            temperature: Sampling temperature for the chat model.
            retrieval_k: Number of documents the retriever returns per query.
        """
        self.embeddings = self._configure_embeddings(embedding_model)
        self.semantic_model = SentenceTransformer('all-MiniLM-L6-v2')
        # No model is named for either pipeline, so the library defaults apply.
        self.sentiment_analyzer = pipeline("sentiment-analysis")
        self.ner_pipeline = pipeline("ner", aggregation_strategy="simple")
        self.llm = self._configure_llm(llm_model, temperature)
        self.vector_db = self._initialize_vector_database()
        self.retriever = self._configure_retriever(retrieval_k)
        # The chain returns both 'answer' and 'source_documents'; the memory
        # must be told which of the two outputs to record as the AI turn.
        self.memory = ConversationBufferMemory(
            memory_key="chat_history",
            return_messages=True,
            output_key="answer",
        )
        self.qa_chain = self._create_conversational_retrieval_chain()

    def _configure_embeddings(self, model_name: str):
        """Build the embedding function, with normalized output vectors."""
        encode_kwargs = {'normalize_embeddings': True, 'show_progress_bar': True}
        return HuggingFaceBgeEmbeddings(model_name=model_name, encode_kwargs=encode_kwargs)

    def _configure_llm(self, model_name: str, temperature: float):
        """Build the Groq chat model (streaming enabled, 4096 max tokens)."""
        return ChatGroq(
            model_name=model_name,
            temperature=temperature,
            max_tokens=4096,
            streaming=True,
        )

    def _initialize_vector_database(self, persist_directory: str = 'vector_db'):
        """Open the Chroma store at ``persist_directory`` using ``self.embeddings``."""
        return Chroma(persist_directory=persist_directory, embedding_function=self.embeddings)

    def _configure_retriever(self, retrieval_k: int):
        """Build an MMR retriever over the vector store.

        Args:
            retrieval_k: Number of documents to return per query.

        Returns:
            The retriever produced by ``self.vector_db.as_retriever``.
        """
        # ``search_type`` is an argument of ``as_retriever`` itself; only the
        # per-search options belong in ``search_kwargs``.
        return self.vector_db.as_retriever(
            search_type="mmr",
            search_kwargs={
                "k": retrieval_k,
                # The candidate pool must be at least as large as ``k``.
                "fetch_k": max(20, retrieval_k),
            },
        )

    def _create_conversational_retrieval_chain(self):
        """Create the conversational retrieval chain with the answer prompt."""
        template = (
            "\n"
            "You are a helpful AI assistant. Provide a precise and comprehensive answer\n"
            "based on the context and chat history.\n"
            "Context: {context}\n"
            "Chat History: {chat_history}\n"
            "Question: {question}\n"
            "Helpful Answer:"
        )
        prompt = ChatPromptTemplate.from_template(template)
        return ConversationalRetrievalChain.from_llm(
            llm=self.llm,
            retriever=self.retriever,
            memory=self.memory,
            combine_docs_chain_kwargs={'prompt': prompt},
            return_source_documents=True,
        )

    def process_query(self, query: str) -> Dict[str, Any]:
        """Answer ``query`` through the RAG chain and analyse the query text.

        Args:
            query: The user's question.

        Returns:
            A dict with the keys:
              * ``response`` - the chain's ``answer`` output.
              * ``source_documents`` - documents returned by the chain
                (empty list when the chain reports none).
              * ``semantic_similarity`` - the query's embedding vector as a
                plain list. Despite the name this is not a similarity score;
                the key is kept so existing callers keep working.
              * ``sentiment`` - first result of the sentiment pipeline.
              * ``named_entities`` - output of the NER pipeline.
        """
        # Query-level NLP analysis (independent of retrieval).
        query_embedding = self.semantic_model.encode([query])[0]
        sentiment_result = self.sentiment_analyzer(query)[0]
        entities = self.ner_pipeline(query)

        # RAG query processing.
        result = self.qa_chain.invoke({"question": query})

        return {
            "response": result['answer'],
            "source_documents": result.get('source_documents', []),
            "semantic_similarity": query_embedding.tolist(),
            "sentiment": sentiment_result,
            "named_entities": entities,
        }
def _get_chatbot(embedding_model, temperature, retrieval_k, reset_chat):
    """Return this session's chatbot, rebuilding it only when settings change.

    Streamlit re-executes the script on every widget interaction, so the
    chatbot is kept in ``st.session_state``; otherwise the models would be
    reloaded and the conversation memory discarded on each rerun.
    """
    settings = (embedding_model, temperature, retrieval_k)
    chatbot = st.session_state.get("chatbot")
    if chatbot is None or st.session_state.get("chatbot_settings") != settings:
        chatbot = AdvancedRAGChatbot(
            embedding_model=embedding_model,
            temperature=temperature,
            retrieval_k=retrieval_k,
        )
        st.session_state["chatbot"] = chatbot
        st.session_state["chatbot_settings"] = settings
    elif reset_chat:
        # A freshly built chatbot already starts with empty memory, so only
        # an existing one needs clearing.
        chatbot.memory.clear()
    return chatbot


def _render_response(response):
    """Render the answer, sentiment, entities and sources of one query result."""
    # Bot Response
    st.markdown("#### Bot's Answer")
    st.write(response['response'])

    # Sentiment Analysis
    st.markdown("#### Sentiment Analysis")
    sentiment = response['sentiment']
    st.metric(
        label="Sentiment",
        value=sentiment['label'],
        delta=f"{sentiment['score']:.2%}",
    )

    # Named Entities
    st.markdown("#### Detected Entities")
    for entity in response['named_entities']:
        # The NER pipeline is built with an aggregation strategy, which
        # reports the label under 'entity_group'; fall back to 'entity' for
        # non-aggregated output.
        label = entity.get('entity_group', entity.get('entity', ''))
        st.text(f"{entity['word']} ({label})")

    # Source Documents
    if response['source_documents']:
        st.markdown("#### Reference Documents")
        for i, doc in enumerate(response['source_documents'], 1):
            with st.expander(f"Document {i}"):
                st.write(doc.page_content)


def main():
    """Run the Streamlit UI: sidebar settings, query input and response panel."""
    # Page Configuration
    st.set_page_config(
        page_title="Advanced RAG Chatbot",
        page_icon="🧠",
        layout="wide",
        initial_sidebar_state="expanded",
    )

    # Sidebar Configuration
    with st.sidebar:
        st.header("🔧 Chatbot Settings")
        st.markdown("Customize your AI assistant's behavior")
        # Model Configuration
        embedding_model = st.selectbox(
            "Embedding Model",
            ["BAAI/bge-large-en-v1.5", "sentence-transformers/all-MiniLM-L6-v2"],
        )
        temperature = st.slider("Creativity Level", 0.0, 1.0, 0.7, help="Higher values make responses more creative")
        retrieval_k = st.slider("Context Depth", 1, 10, 5, help="Number of reference documents to retrieve")
        # Additional Controls
        st.divider()
        reset_chat = st.button("🔄 Reset Conversation")

    # Initialize (or reuse) the chatbot for this session
    chatbot = _get_chatbot(embedding_model, temperature, retrieval_k, reset_chat)

    # Main Chat Interface
    st.title("🤖 Advanced RAG Chatbot")

    # Two-column layout
    col1, col2 = st.columns(2)
    with col1:
        st.header("Input")
        # Chat input with placeholder
        user_input = st.text_area(
            "Ask your question",
            placeholder="Enter your query here...",
            height=250,
        )
        # Submit button
        submit_button = st.button("Send Query", type="primary")

    with col2:
        st.header("Response")
        # Treat a whitespace-only query as empty.
        if submit_button and user_input and user_input.strip():
            with st.spinner("Processing your query..."):
                # Top-level UI boundary: surface any failure to the user
                # instead of crashing the page.
                try:
                    response = chatbot.process_query(user_input)
                    _render_response(response)
                except Exception as e:
                    st.error(f"An error occurred: {e}")
        else:
            st.info("Submit a query to see the AI's response")
# Script entry point: start the app only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()