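"""Gradio chatbot for the Vincent Mary School of Science and Technology (VMS).

Retrieval-augmented generation (RAG) pipeline: NVIDIA AI endpoints for the LLM
and embeddings, MongoDB Atlas vector search for retrieval, and LangChain for
history-aware question rewriting and answer generation.
"""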
import os

import gradio as gr
import pymongo
from langchain.chains import create_history_aware_retriever, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_community.vectorstores import FAISS, MongoDBAtlasVectorSearch
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.output_parsers import StrOutputParser  # used by the legacy prototype below
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_nvidia_ai_endpoints import ChatNVIDIA, NVIDIAEmbeddings
# Embedding model used for both the local FAISS index and Atlas vector search.
embedder = NVIDIAEmbeddings(model="nvolveqa_40k", model_type=None)

# Local FAISS index (only used by the legacy prototype at the bottom of this file).
faiss_db = FAISS.load_local("vms_faiss_index", embedder, allow_dangerous_deserialization=True)
# docs = faiss_db.similarity_search(query)

# ChatNVIDIA and NVIDIAEmbeddings read NVIDIA_API_KEY from the environment.
nvidia_api_key = os.environ.get("NVIDIA_API_KEY", "")
def get_mongo_client(mongo_uri):
    """Establish a connection to MongoDB and verify it with a ping."""
    try:
        client = pymongo.MongoClient(mongo_uri)
        # MongoClient connects lazily, so force a round trip to validate the URI.
        client.admin.command("ping")
        print("Connection to MongoDB successful")
        return client
    except pymongo.errors.PyMongoError as e:
        print(f"Connection failed: {e}")
        return None
mongo_uri = os.environ.get("MyCluster_MONGO_URI")
if not mongo_uri:
    raise RuntimeError("MyCluster_MONGO_URI not set in environment variables")
mongo_client = get_mongo_client(mongo_uri)
DB_NAME = "vms_courses"
COLLECTION_NAME = "courses"
mongo_db = mongo_client[DB_NAME]
collection = mongo_db[COLLECTION_NAME]
ATLAS_VECTOR_SEARCH_INDEX_NAME = "vector_index"
vector_search = MongoDBAtlasVectorSearch.from_connection_string(
mongo_uri,
DB_NAME + "." + COLLECTION_NAME,
embedder,
index_name=ATLAS_VECTOR_SEARCH_INDEX_NAME,
)
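# NOTE (assumption): MongoDBAtlasVectorSearch queries an existing Atlas Search
# index; an index named "vector_index" must already be defined on this
# collection's embedding field (with dimensions matching the embedder), or
# retrieval will return no results. Example definition (adjust path/dims):
# {"fields": [{"type": "vector", "path": "embedding",
#              "numDimensions": 1024, "similarity": "cosine"}]}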
llm = ChatNVIDIA(model="mixtral_8x7b")

# Retrieve the 12 most similar chunks for each (rewritten) question.
retriever = vector_search.as_retriever(
    search_type="similarity",
    search_kwargs={"k": 12},
)
### Contextualize question ###
contextualize_q_system_prompt = """Given a chat history and the latest user question \
which might reference context in the chat history, formulate a standalone question \
which can be understood without the chat history. Do NOT answer the question, \
just reformulate it if needed and otherwise return it as is."""
contextualize_q_prompt = ChatPromptTemplate.from_messages(
[
("system", contextualize_q_system_prompt),
MessagesPlaceholder("chat_history"),
("human", "{input}"),
]
)
history_aware_retriever = create_history_aware_retriever(
llm, retriever, contextualize_q_prompt
)
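# The history-aware retriever first asks the LLM to rewrite a follow-up
# question into a standalone one, then runs the rewritten query through
# `retriever`; on the first turn (empty history) the input passes through as-is.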
### Answer question ###
qa_system_prompt = """You are a VMS assistant helping students with their academics. \
Answer the question using only the context provided, and do not use phrases such as \
"based on the context" or "based on the documents provided" in your answer. \
Remember that your job is to represent the Vincent Mary School of Science and \
Technology (VMS) at Assumption University. Do not hallucinate any details, and do \
not repeat information. If you don't know the answer, just say that you don't know. \
{context}"""
qa_prompt = ChatPromptTemplate.from_messages(
[
("system", qa_system_prompt),
MessagesPlaceholder("chat_history"),
("human", "{input}"),
]
)
question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
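# `rag_chain.invoke` takes {"input", "chat_history"} and returns a dict with
# "input", "chat_history", "context" (the retrieved documents), and "answer".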
### Statefully manage chat history ###
store = {}

def get_session_history(session_id: str) -> BaseChatMessageHistory:
    """Return (creating if needed) the in-memory message history for a session."""
    if session_id not in store:
        store[session_id] = ChatMessageHistory()
    return store[session_id]
conversational_rag_chain = RunnableWithMessageHistory(
rag_chain,
get_session_history,
input_messages_key="input",
history_messages_key="chat_history",
output_messages_key="answer",
)
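# Example call (session_id selects which ChatMessageHistory in `store` is used):
# conversational_rag_chain.invoke(
#     {"input": "What courses does VMS offer?"},
#     config={"configurable": {"session_id": "demo"}},
# )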
def chat_gen(message, history):
    # Gradio's `history` argument is ignored; chat history is tracked
    # server-side by `conversational_rag_chain` under one shared session id,
    # which also ensures the bot's replies are stored as AI messages rather
    # than raw strings.
    ai_message = conversational_rag_chain.invoke(
        {"input": message},
        config={"configurable": {"session_id": "gradio"}},
    )
    yield ai_message["answer"]
    # for doc in ai_message["context"]:
    #     yield doc
initial_msg = (
    "Hello! I am the VMS bot, here to help you with your academic issues!"
    "\nHow can I help you?"
)
chatbot = gr.Chatbot(value=[[None, initial_msg]], bubble_full_width=False)
demo = gr.ChatInterface(chat_gen, chatbot=chatbot).queue()
try:
    demo.launch(debug=True, share=True, show_api=False)
except Exception as e:
    print(e)
    raise
finally:
    demo.close()
# --- Legacy prototype (FAISS-only, no chat history), kept for reference ---
# Available model names:
# mixtral_8x7b
# llama2_13b
# llm = ChatNVIDIA(model="mixtral_8x7b") | StrOutputParser()
# initial_msg = (
# "Hello! I am VMS bot here to help you with your academic issues!"
# f"\nHow can I help you?"
# )
# context_prompt = ChatPromptTemplate.from_messages([
# ('system',
# "You are a VMS chatbot, and you are helping students with their academic issues."
# "Answer the question using only the context provided. Do not include based on the context or based on the documents provided in your answer."
# "Please help them with their question. Remember that your job is to represent Vicent Mary School of Science and Technology (VMS) at Assumption University."
# "Do not hallucinate any details, and make sure the knowledge base is not redundant."
# "Please say you do not know if you do not know or you cannot find the information needed."
# "\n\nQuestion: {question}\n\nContext: {context}"),
# ('user', "{question}"
# )])
# chain = (
# {
# 'context': faiss_db.as_retriever(search_type="similarity"),
# 'question': (lambda x:x)
# }
# | context_prompt
# # | RPrint()
# | llm
# | StrOutputParser()
# )
# conv_chain = (
# context_prompt
# # | RPrint()
# | llm
# | StrOutputParser()
# )
# def chat_gen(message, history, return_buffer=True):
# buffer = ""
# doc_retriever = faiss_db.as_retriever(search_type="similarity_score_threshold", search_kwargs={"score_threshold": 0.2})
# retrieved_docs = doc_retriever.invoke(message)
# print(len(retrieved_docs))
# print(retrieved_docs)
# if len(retrieved_docs) > 0:
# state = {
# 'question': message,
# 'context': retrieved_docs
# }
# for token in conv_chain.stream(state):
# buffer += token
# yield buffer
# else:
# passage = "I am sorry. I do not have relevant information to answer on that specific topic. Please try another question."
# buffer += passage
# yield buffer if return_buffer else passage
# chatbot = gr.Chatbot(value = [[None, initial_msg]])
# iface = gr.ChatInterface(chat_gen, chatbot=chatbot).queue()
# iface.launch()