import os

import gradio as gr
from langchain_community.vectorstores import FAISS
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_nvidia_ai_endpoints import ChatNVIDIA, NVIDIAEmbeddings
embedder = NVIDIAEmbeddings(model="nvolveqa_40k", model_type=None)
db = FAISS.load_local("data_first_faiss_index", embedder, allow_dangerous_deserialization=True)
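# Note: "data_first_faiss_index" must already exist on disk. A minimal sketch of
# how such an index could be built (assumes a list of LangChain Document objects
# named `docs`; that variable name is illustrative, not from the original):
#   index = FAISS.from_documents(docs, embedder)
#   index.save_local("data_first_faiss_index")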
# Example lookup against the loaded index:
# docs = db.similarity_search(query)

# ChatNVIDIA reads NVIDIA_API_KEY from the environment; fetching it here makes
# the dependency explicit.
nvidia_api_key = os.environ.get("NVIDIA_API_KEY", "")
# Available model names:
#   mixtral_8x7b
#   llama2_13b
llm = ChatNVIDIA(model="mixtral_8x7b")
initial_msg = (
    "Hello! I am a chatbot to help with any questions about Data First Company."
    "\nHow can I help you?"
)
context_prompt = ChatPromptTemplate.from_messages([
    ('system',
     "You are a chatbot helping customers with their inquiries about Data First Company. "
     "Answer the question using only the context provided, and do not include phrases "
     "such as 'based on the context' or 'based on the documents provided' in your answer. "
     "Remember that your job is to represent Data First, a company that creates data solutions. "
     "Do not hallucinate any details, and do not repeat information in your answer. "
     "If you do not know the answer or cannot find the needed information, say so."
     "\n\nQuestion: {question}\n\nContext: {context}"),
    ('user', "{question}")
])
# Single-shot RAG chain: retrieves context for the question, then answers.
# (Defined for standalone use; the Gradio handler below uses conv_chain.)
chain = (
    {
        'context': db.as_retriever(search_type="similarity"),
        'question': RunnablePassthrough(),
    }
    | context_prompt
    | llm
    | StrOutputParser()
)
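# Illustrative one-off call to the single-shot chain (the question string is an
# example, not from the original):
#   answer = chain.invoke("What services does Data First Company offer?")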
# Conversational chain used by the Gradio handler: the context is retrieved
# separately in chat_gen and passed in alongside the question.
conv_chain = (
    context_prompt
    | llm
    | StrOutputParser()
)
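# conv_chain expects a dict carrying both template variables, e.g. (assuming
# `retrieved_docs` holds documents from the retriever; the question string is
# illustrative):
#   conv_chain.invoke({'question': "What does Data First offer?", 'context': retrieved_docs})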
def chat_gen(message, history, return_buffer=True):
    buffer = ""
    # Retrieve documents above the similarity threshold; an empty result means
    # the index holds nothing relevant to the question.
    doc_retriever = db.as_retriever(
        search_type="similarity_score_threshold",
        search_kwargs={"score_threshold": 0.2},
    )
    retrieved_docs = doc_retriever.invoke(message)
    print(len(retrieved_docs))  # debug: number of retrieved documents
    print(retrieved_docs)       # debug: retrieved document contents
    if len(retrieved_docs) > 0:
        state = {
            'question': message,
            'context': retrieved_docs
        }
        for token in conv_chain.stream(state):
            buffer += token
            yield buffer if return_buffer else token
    else:
        passage = "I am sorry. I do not have relevant information to answer the question. Please try another question."
        buffer += passage
        yield buffer if return_buffer else passage
chatbot = gr.Chatbot(value=[[None, initial_msg]])
iface = gr.ChatInterface(chat_gen, chatbot=chatbot).queue()
iface.launch()
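# When running outside a Hugging Face Space, a temporary public link can be
# requested instead (a deployment assumption, not part of the original):
#   iface.launch(share=True)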