Update app.py
app.py
CHANGED
@@ -1,49 +1,131 @@
 import os
+import streamlit as st
+import torch
+from typing import List, Dict, Any
 from langchain_core.prompts import ChatPromptTemplate
 from langchain_groq import ChatGroq
-from transformers import pipeline
-import torch
-from groq import Groq
 from langchain_community.vectorstores import Chroma
+from langchain_community.embeddings import HuggingFaceBgeEmbeddings
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain.chains import RetrievalQA
-from
-from langchain.
-
-
-
-GROQ_API_KEY = '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'
-
-groq_client = Groq(api_key=GROQ_API_KEY)
-
-def configure_groq_llm(model_name="llama-3.3-70b-versatile", temperature=0.7, max_tokens=2048):
-    return ChatGroq(groq_api_key=GROQ_API_KEY, model_name=model_name, temperature=temperature, max_tokens=max_tokens)
-
-def get_embeddings(model_name="BAAI/bge-base-en"):
-    encode_kwargs = {'normalize_embeddings': True}
-    return HuggingFaceBgeEmbeddings(model_name=model_name, encode_kwargs=encode_kwargs)
-
-def create_llama_prompt():
-    template = """ Use the following context to answer the question: Context: {context} Question: {question} Helpful Answer:"""
-    return PromptTemplate(template=template, input_variables=["context", "question"])
-
-embeddings = get_embeddings()
-llm = configure_groq_llm()
-vector_db = Chroma(persist_directory='db', embedding_function=embeddings)
-retriever = vector_db.as_retriever(search_kwargs={"k": 5})
-prompt = create_llama_prompt()
+from langchain.memory import ConversationBufferMemory
+from langchain.chains import ConversationalRetrievalChain
+from transformers import pipeline
+from sentence_transformers import SentenceTransformer
+import numpy as np
 
-
+class AdvancedRAGChatbot:
+    def __init__(self,
+                 embedding_model: str = "BAAI/bge-large-en-v1.5",
+                 llm_model: str = "llama-3.3-70b-versatile",
+                 temperature: float = 0.7,
+                 retrieval_k: int = 5):
+        self.embeddings = self._configure_embeddings(embedding_model)
+        self.semantic_model = SentenceTransformer('all-MiniLM-L6-v2')
+        self.sentiment_analyzer = pipeline("sentiment-analysis")
+        self.ner_pipeline = pipeline("ner", aggregation_strategy="simple")
+        self.llm = self._configure_llm(llm_model, temperature)
+        self.vector_db = self._initialize_vector_database()
+        self.retriever = self._configure_retriever(retrieval_k)
+        self.memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True, output_key="answer")
+        self.qa_chain = self._create_conversational_retrieval_chain()
+
+    def _configure_embeddings(self, model_name: str):
+        encode_kwargs = {'normalize_embeddings': True, 'show_progress_bar': True}
+        return HuggingFaceBgeEmbeddings(model_name=model_name, encode_kwargs=encode_kwargs)
+
+    def _configure_llm(self, model_name: str, temperature: float):
+        return ChatGroq(
+            model_name=model_name,
+            temperature=temperature,
+            max_tokens=4096,
+            streaming=True
+        )
+
+    def _initialize_vector_database(self, persist_directory: str = 'vector_db'):
+        return Chroma(persist_directory=persist_directory, embedding_function=self.embeddings)
+
+    def _configure_retriever(self, retrieval_k: int):
+        return self.vector_db.as_retriever(search_type="mmr", search_kwargs={"k": retrieval_k, "fetch_k": 20})
+
+    def _create_conversational_retrieval_chain(self):
+        template = """
+        You are a helpful AI assistant. Use the following context and chat history to provide a precise answer.
+
+        Context: {context}
+        Chat History: {chat_history}
+        Question: {question}
+
+        Helpful Answer:"""
+
+        prompt = ChatPromptTemplate.from_template(template)
+        return ConversationalRetrievalChain.from_llm(
+            llm=self.llm,
+            retriever=self.retriever,
+            memory=self.memory,
+            combine_docs_chain_kwargs={'prompt': prompt},
+            return_source_documents=True
+        )
+
+    def process_query(self, query: str) -> Dict[str, Any]:
+        semantic_score = self.semantic_model.encode([query])[0]
+        sentiment_result = self.sentiment_analyzer(query)[0]
+        entities = self.ner_pipeline(query)
+        result = self.qa_chain({"question": query})
+
+        response_data = {
+            "response": result['answer'],
+            "source_documents": result.get('source_documents', []),
+            "semantic_similarity": semantic_score.tolist(),
+            "sentiment": sentiment_result,
+            "named_entities": entities,
+            "contextual_information": result.get("source_documents", [])
+        }
+        return response_data
 
-def
-st.
-
-if user_input:
-
-    response =
-
-
-    st.
-
+def main():
+    st.set_page_config(page_title="Advanced NLP RAG Chatbot", layout="wide", initial_sidebar_state="expanded")
+    st.title("🧠 Advanced NLP RAG Chatbot")
+
+    with st.sidebar:
+        st.header("Configuration")
+        embedding_model = st.selectbox(
+            "Embedding Model",
+            ["BAAI/bge-large-en-v1.5", "sentence-transformers/all-MiniLM-L6-v2"]
+        )
+        temperature = st.slider("Model Temperature", 0.0, 1.0, 0.7)
+        retrieval_k = st.slider("Documents to Retrieve (k)", 1, 10, 5)
+
+    chatbot = AdvancedRAGChatbot(
+        embedding_model=embedding_model,
+        temperature=temperature,
+        retrieval_k=retrieval_k
+    )
+
+    st.markdown("### Chat with the AI Assistant")
+    query_col, response_col = st.columns(2)
+
+    with query_col:
+        user_input = st.text_area("Ask your question:", placeholder="Type your question here...", height=150)
+
+    if user_input:
+        with st.spinner("Processing your query..."):
+            response = chatbot.process_query(user_input)
+
+        with response_col:
+            st.markdown("### Bot Response")
+            st.write(response['response'])
+
+            st.markdown("### Sentiment Analysis")
+            st.write(f"Sentiment: {response['sentiment']['label']} ({response['sentiment']['score']:.2%})")
+
+            st.markdown("### Named Entities")
+            for entity in response['named_entities']:
+                st.write(f"- {entity['word']} ({entity['entity_group']})")
+
+            st.markdown("### Source Documents")
+            for doc in response['source_documents']:
+                st.text_area("Source Document", doc.page_content, height=100)
+
 if __name__ == "__main__":
-
+    main()
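Note that the new app reads a persisted Chroma index from vector_db (the old code used db) but never builds one, and RecursiveCharacterTextSplitter is imported without being used. A one-off ingestion script along the following lines would produce the expected directory; this is a sketch rather than part of the commit, and the docs/ corpus location, the use of TextLoader, and the chunk sizes are assumptions:

# ingest.py -- one-off indexing sketch; assumes plain-text files under docs/
# and the same embedding model that app.py defaults to.
import os

from langchain_community.document_loaders import TextLoader
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.vectorstores import Chroma
from langchain.text_splitter import RecursiveCharacterTextSplitter

embeddings = HuggingFaceBgeEmbeddings(
    model_name="BAAI/bge-large-en-v1.5",
    encode_kwargs={"normalize_embeddings": True},
)

# Load every .txt file under docs/ (hypothetical corpus location).
documents = []
for name in os.listdir("docs"):
    if name.endswith(".txt"):
        documents.extend(TextLoader(os.path.join("docs", name)).load())

# Chunk the documents so retrieval returns focused passages.
splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
chunks = splitter.split_documents(documents)

# Persist the index to the directory app.py reads from.
Chroma.from_documents(chunks, embeddings, persist_directory="vector_db")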
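The commit also drops the hardcoded Groq key, so ChatGroq falls back to the GROQ_API_KEY environment variable; it must be set as a Space secret, or exported locally before streamlit run app.py. A minimal fail-fast guard, again a sketch and not part of the commit, could sit at the top of main():

# ChatGroq reads GROQ_API_KEY from the environment when no groq_api_key
# argument is passed; stop early with a clear message instead of failing
# deep inside the chain.
import os
import streamlit as st

if not os.environ.get("GROQ_API_KEY"):
    st.error("GROQ_API_KEY is not set. Add it as a Space secret or export it before launching.")
    st.stop()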