HEHEBOIOG committed
Commit d047c3e · verified · 1 Parent(s): 40eeec4

Update app.py

Files changed (1)
  1. app.py +121 -39
app.py CHANGED
@@ -1,49 +1,131 @@
 import os
+import streamlit as st
+import torch
+from typing import List, Dict, Any
 from langchain_core.prompts import ChatPromptTemplate
 from langchain_groq import ChatGroq
-from transformers import pipeline
-import torch
-from groq import Groq
 from langchain_community.vectorstores import Chroma
+from langchain_community.embeddings import HuggingFaceBgeEmbeddings
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain.chains import RetrievalQA
-from langchain_community.embeddings import HuggingFaceBgeEmbeddings
-from langchain.prompts import PromptTemplate
-import streamlit as st
-
-# GROQ_API_KEY = os.getenv('GROQ_API_KEY')
-GROQ_API_KEY = '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'
-
-groq_client = Groq(api_key=GROQ_API_KEY)
-
-def configure_groq_llm(model_name="llama-3.3-70b-versatile", temperature=0.7, max_tokens=2048):
-    return ChatGroq(groq_api_key=GROQ_API_KEY, model_name=model_name, temperature=temperature, max_tokens=max_tokens)
-
-def get_embeddings(model_name="BAAI/bge-base-en"):
-    encode_kwargs = {'normalize_embeddings': True}
-    return HuggingFaceBgeEmbeddings(model_name=model_name, encode_kwargs=encode_kwargs)
-
-def create_llama_prompt():
-    template = """ Use the following context to answer the question: Context: {context} Question: {question} Helpful Answer:"""
-    return PromptTemplate(template=template, input_variables=["context", "question"])
-
-embeddings = get_embeddings()
-llm = configure_groq_llm()
-vector_db = Chroma(persist_directory='db', embedding_function=embeddings)
-retriever = vector_db.as_retriever(search_kwargs={"k": 5})
-prompt = create_llama_prompt()
+from langchain.memory import ConversationBufferMemory
+from langchain.chains import ConversationalRetrievalChain
+from transformers import pipeline
+from sentence_transformers import SentenceTransformer
+import numpy as np
 
-qa_chain = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever, chain_type_kwargs={"prompt": prompt}, return_source_documents=True)
+class AdvancedRAGChatbot:
+    def __init__(self,
+                 embedding_model: str = "BAAI/bge-large-en-v1.5",
+                 llm_model: str = "llama-3.3-70b-versatile",
+                 temperature: float = 0.7,
+                 retrieval_k: int = 5):
+        self.embeddings = self._configure_embeddings(embedding_model)
+        self.semantic_model = SentenceTransformer('all-MiniLM-L6-v2')
+        self.sentiment_analyzer = pipeline("sentiment-analysis")
+        self.ner_pipeline = pipeline("ner", aggregation_strategy="simple")
+        self.llm = self._configure_llm(llm_model, temperature)
+        self.vector_db = self._initialize_vector_database()
+        self.retriever = self._configure_retriever(retrieval_k)
+        self.memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True, output_key="answer")
+        self.qa_chain = self._create_conversational_retrieval_chain()
+
+    def _configure_embeddings(self, model_name: str):
+        encode_kwargs = {'normalize_embeddings': True, 'show_progress_bar': True}
+        return HuggingFaceBgeEmbeddings(model_name=model_name, encode_kwargs=encode_kwargs)
+
+    def _configure_llm(self, model_name: str, temperature: float):
+        return ChatGroq(
+            model_name=model_name,
+            temperature=temperature,
+            max_tokens=4096,
+            streaming=True
+        )
+
+    def _initialize_vector_database(self, persist_directory: str = 'vector_db'):
+        return Chroma(persist_directory=persist_directory, embedding_function=self.embeddings)
+
+    def _configure_retriever(self, retrieval_k: int):
+        return self.vector_db.as_retriever(search_type="mmr", search_kwargs={"k": retrieval_k, "fetch_k": 20})
+
+    def _create_conversational_retrieval_chain(self):
+        template = """
+        You are a helpful AI assistant. Use the following context and chat history to provide a precise answer.
+
+        Context: {context}
+        Chat History: {chat_history}
+        Question: {question}
+
+        Helpful Answer:"""
+
+        prompt = ChatPromptTemplate.from_template(template)
+        return ConversationalRetrievalChain.from_llm(
+            llm=self.llm,
+            retriever=self.retriever,
+            memory=self.memory,
+            combine_docs_chain_kwargs={'prompt': prompt},
+            return_source_documents=True
+        )
+
+    def process_query(self, query: str) -> Dict[str, Any]:
+        semantic_score = self.semantic_model.encode([query])[0]
+        sentiment_result = self.sentiment_analyzer(query)[0]
+        entities = self.ner_pipeline(query)
+        result = self.qa_chain.invoke({"question": query})
+
+        response_data = {
+            "response": result['answer'],
+            "source_documents": result.get('source_documents', []),
+            "semantic_similarity": semantic_score.tolist(),
+            "sentiment": sentiment_result,
+            "named_entities": entities,
+            "contextual_information": result.get("source_documents", [])
+        }
+        return response_data
 
-def groq_nlp_chatbot():
-    st.title("Groq Llama 3.2 Chatbot")
-    user_input = st.text_input("Your Question:")
+def main():
+    st.set_page_config(page_title="Advanced NLP RAG Chatbot", layout="wide", initial_sidebar_state="expanded")
+    st.title("🧠 Advanced NLP RAG Chatbot")
+
+    with st.sidebar:
+        st.header("Configuration")
+        embedding_model = st.selectbox(
+            "Embedding Model",
+            ["BAAI/bge-large-en-v1.5", "sentence-transformers/all-MiniLM-L6-v2"]
+        )
+        temperature = st.slider("Model Temperature", 0.0, 1.0, 0.7)
+        retrieval_k = st.slider("Documents to Retrieve (k)", 1, 10, 5)
+
+    chatbot = AdvancedRAGChatbot(
+        embedding_model=embedding_model,
+        temperature=temperature,
+        retrieval_k=retrieval_k
+    )
+
+    st.markdown("### Chat with the AI Assistant")
+    query_col, response_col = st.columns(2)
+
+    with query_col:
+        user_input = st.text_area("Ask your question:", placeholder="Type your question here...", height=150)
+
     if user_input:
-        try:
-            response = qa_chain.invoke(user_input)
-            st.text_area("Bot's Response:", response['result'])
-        except Exception as e:
-            st.error(f"Error processing request: {e}")
-
+        with st.spinner("Processing your query..."):
+            response = chatbot.process_query(user_input)
+
+        with response_col:
+            st.markdown("### Bot Response")
+            st.write(response['response'])
+
+            st.markdown("### Sentiment Analysis")
+            st.write(f"Sentiment: {response['sentiment']['label']} ({response['sentiment']['score']:.2%})")
+
+            st.markdown("### Named Entities")
+            for entity in response['named_entities']:
+                st.write(f"- {entity['word']} ({entity['entity_group']})")
+
+            st.markdown("### Source Documents")
+            for doc in response['source_documents']:
+                st.text_area("Source Document", doc.page_content, height=100)
+
 if __name__ == "__main__":
-    groq_nlp_chatbot()
+    main()
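
The removed version hard-coded a Groq API key in app.py; a key committed once lives on in the repository history, so it should be treated as leaked and revoked. The rewritten _configure_llm passes no key at all, which works because langchain_groq falls back to the GROQ_API_KEY environment variable. A minimal sketch of that pattern, assuming the key arrives as a Space secret or a shell export (the guard below is illustrative, not part of the commit):

    # Fail fast if the key is missing instead of failing on the first request.
    # Assumes GROQ_API_KEY is supplied via a Space secret or `export GROQ_API_KEY=...`.
    import os

    from langchain_groq import ChatGroq

    if "GROQ_API_KEY" not in os.environ:
        raise RuntimeError("GROQ_API_KEY is not set; configure it before launching the app")

    llm = ChatGroq(model_name="llama-3.3-70b-versatile", temperature=0.7)  # key is read from the environment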
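
Because Streamlit reruns the whole script on every widget interaction, main() as committed rebuilds AdvancedRAGChatbot, reloading the BGE embeddings, the all-MiniLM-L6-v2 encoder, and the sentiment and NER pipelines, each time a slider moves. One possible optimization, not part of this commit, is to memoize construction with st.cache_resource; a sketch under that assumption:

    # Hypothetical optimization: cache one chatbot per configuration so model
    # weights are loaded once per process instead of on every Streamlit rerun.
    import streamlit as st

    @st.cache_resource
    def get_chatbot(embedding_model: str, temperature: float, retrieval_k: int):
        return AdvancedRAGChatbot(
            embedding_model=embedding_model,
            temperature=temperature,
            retrieval_k=retrieval_k,
        )

    # In main(), the direct constructor call would become:
    # chatbot = get_chatbot(embedding_model, temperature, retrieval_k)

One caveat: the instance holds a ConversationBufferMemory, so a cached chatbot would share chat history across all sessions; per-session history would need to live in st.session_state instead.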
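
For a quick check outside the UI, the class can be exercised directly. The sketch below assumes the 'vector_db' Chroma directory already contains indexed documents, GROQ_API_KEY is set, and the Hugging Face models can be downloaded; the question string is only an example:

    # Hypothetical smoke test; save next to app.py and run with `python smoke_test.py`.
    from app import AdvancedRAGChatbot

    bot = AdvancedRAGChatbot(retrieval_k=3)
    out = bot.process_query("What topics does the indexed corpus cover?")

    print(out["response"])                # the chain's answer
    print(out["sentiment"])               # e.g. {'label': 'POSITIVE', 'score': 0.97}
    for doc in out["source_documents"]:
        print(doc.page_content[:80])      # start of each retrieved chunk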