import os

import gradio as gr
import pymongo
from langchain.chains import create_history_aware_retriever, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_community.vectorstores import FAISS, MongoDBAtlasVectorSearch
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.output_parsers import StrOutputParser  # used by the legacy chain kept in comments below
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_nvidia_ai_endpoints import ChatNVIDIA, NVIDIAEmbeddings


embedder = NVIDIAEmbeddings(model="nvolveqa_40k", model_type=None)

# Local FAISS index over the VMS documents (loaded here, but the live chain
# below retrieves from MongoDB Atlas instead).
db = FAISS.load_local("vms_faiss_index", embedder, allow_dangerous_deserialization=True)

# docs = db.similarity_search(query)

# ChatNVIDIA and NVIDIAEmbeddings read NVIDIA_API_KEY from the environment.
nvidia_api_key = os.environ.get("NVIDIA_API_KEY", "")



def get_mongo_client(mongo_uri):
    """Establish a connection to MongoDB."""
    try:
        client = pymongo.MongoClient(mongo_uri)
        client.admin.command("ping")  # force a round trip; MongoClient alone connects lazily
        print("Connection to MongoDB successful")
        return client
    except pymongo.errors.ConnectionFailure as e:
        print(f"Connection failed: {e}")
        return None

mongo_uri = os.environ.get("MyCluster_MONGO_URI")
if not mongo_uri:
    print("MyCluster_MONGO_URI not set in environment variables")

mongo_client = get_mongo_client(mongo_uri)
if mongo_client is None:
    raise SystemExit("Cannot continue without a MongoDB connection.")

DB_NAME = "vms_courses"
COLLECTION_NAME = "courses"

# Use a distinct name so the FAISS index loaded above (`db`) is not shadowed.
mongo_db = mongo_client[DB_NAME]
collection = mongo_db[COLLECTION_NAME]
ATLAS_VECTOR_SEARCH_INDEX_NAME = "vector_index"


vector_search = MongoDBAtlasVectorSearch.from_connection_string(
    mongo_uri,
    DB_NAME + "." + COLLECTION_NAME,
    embedder,
    index_name=ATLAS_VECTOR_SEARCH_INDEX_NAME,
)
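
# The collection must carry an Atlas Vector Search index named "vector_index"
# whose dimension matches the embedding model. A minimal sketch of such an
# index definition ("embedding" is LangChain's default field; the dimension of
# nvolveqa_40k is assumed here -- verify against the actual documents):
#
# {
#   "fields": [
#     {"type": "vector", "path": "embedding", "numDimensions": 1024, "similarity": "cosine"}
#   ]
# }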


llm = ChatNVIDIA(model="mixtral_8x7b")

retriever = vector_search.as_retriever(
    search_type="similarity",
    search_kwargs={"k": 12},
)
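# k=12 pulls the twelve nearest chunks per query; lower it if stuffed prompts
# start exceeding the model's context window.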



### Contextualize question ###
contextualize_q_system_prompt = """Given a chat history and the latest user question \
which might reference context in the chat history, formulate a standalone question \
which can be understood without the chat history. Do NOT answer the question, \
just reformulate it if needed and otherwise return it as is."""
contextualize_q_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", contextualize_q_system_prompt),
        MessagesPlaceholder("chat_history"),
        ("human", "{input}"),
    ]
)
history_aware_retriever = create_history_aware_retriever(
    llm, retriever, contextualize_q_prompt
)
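# create_history_aware_retriever first has the LLM rewrite the latest question
# into a standalone query (using chat_history), then retrieves with that query.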


### Answer question ###
qa_system_prompt = """You are a VMS assistant helping students with their academic questions. \
Answer the question using only the context provided, and do not use phrases such as \
"based on the context" or "based on the documents provided" in your answer. \
Remember that your job is to represent the Vincent Mary School of Science and Technology (VMS) \
at Assumption University. Do not make up any details, and do not repeat information. \
If you don't know the answer, just say that you don't know.

{context}"""

qa_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", qa_system_prompt),
        MessagesPlaceholder("chat_history"),
        ("human", "{input}"),
    ]
)
question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)

rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
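# rag_chain.invoke({"input": ..., "chat_history": [...]}) returns a dict with
# "context" (the retrieved documents) and "answer" (the generated reply).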


### Statefully manage chat history ###
store = {}


def get_session_history(session_id: str) -> BaseChatMessageHistory:
    if session_id not in store:
        store[session_id] = ChatMessageHistory()
    return store[session_id]
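
# Session histories live only in this process's memory and are lost on restart;
# swap in a persistent BaseChatMessageHistory implementation if that matters.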


conversational_rag_chain = RunnableWithMessageHistory(
    rag_chain,
    get_session_history,
    input_messages_key="input",
    history_messages_key="chat_history",
    output_messages_key="answer",
)
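
# Usage sketch (hypothetical session id); note that chat_gen below bypasses
# this wrapper and manages its own global history instead:
#
# conversational_rag_chain.invoke(
#     {"input": "What courses does VMS offer?"},
#     config={"configurable": {"session_id": "demo"}},
# )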

c_history = []

def chat_gen(message, history):
    # Note: this keeps one global history (c_history) shared by all Gradio
    # sessions, rather than using conversational_rag_chain's per-session store.
    ai_message = rag_chain.invoke({"input": message, "chat_history": c_history})
    # Store the answer as an AIMessage so the history stays a list of messages.
    c_history.extend([HumanMessage(content=message), AIMessage(content=ai_message["answer"])])
    print(c_history)
    yield ai_message["answer"]

    # for doc in ai_message["context"]:
    #     yield doc

initial_msg = (
    "Hello! I am VMS bot, here to help you with your academic issues!"
    "\nHow can I help you?"
)


chatbot = gr.Chatbot(value=[[None, initial_msg]], bubble_full_width=False)
demo = gr.ChatInterface(chat_gen, chatbot=chatbot).queue()

try:
    demo.launch(debug=True, share=True, show_api=False)
except Exception as e:
    print(e)
    raise
finally:
    demo.close()

# available model names:
# mixtral_8x7b
# llama2_13b
# llm = ChatNVIDIA(model="mixtral_8x7b") | StrOutputParser()

# initial_msg = (
#     "Hello! I am VMS bot here to help you with your academic issues!"
#     f"\nHow can I help you?"
# )

# context_prompt = ChatPromptTemplate.from_messages([
#     ('system',
#         "You are a VMS chatbot, and you are helping students with their academic issues."
#         "Answer the question using only the context provided. Do not include based on the context or based on the documents provided in your answer."
#         "Please help them with their question. Remember that your job is to represent Vicent Mary School of Science and Technology (VMS) at Assumption University."
#         "Do not hallucinate any details, and make sure the knowledge base is not redundant."
#         "Please say you do not know if you do not know or you cannot find the information needed."
#         "\n\nQuestion: {question}\n\nContext: {context}"),
#     ('user', "{question}")
# ])

# chain = (
#     {
#         'context': db.as_retriever(search_type="similarity"),
#         'question': (lambda x:x)
#     }
#     | context_prompt
#     # | RPrint()
#     | llm
#     | StrOutputParser()
# )

# conv_chain = (
#     context_prompt
#     # | RPrint()
#     | llm
#     | StrOutputParser()
# )

# def chat_gen(message, history, return_buffer=True):
#     buffer = ""

#     doc_retriever = db.as_retriever(search_type="similarity_score_threshold", search_kwargs={"score_threshold": 0.2})
#     retrieved_docs = doc_retriever.invoke(message)
#     print(len(retrieved_docs))
#     print(retrieved_docs)

#     if len(retrieved_docs) > 0:
#         state = {
#                 'question': message,
#                 'context': retrieved_docs
#         }
#         for token in conv_chain.stream(state):
#             buffer += token
#             yield buffer
#     else:
#       passage = "I am sorry. I do not have relevant information to answer on that specific topic. Please try another question."
#       buffer += passage
#       yield buffer if return_buffer else passage


# chatbot = gr.Chatbot(value = [[None, initial_msg]])
# iface = gr.ChatInterface(chat_gen, chatbot=chatbot).queue()
# iface.launch()