Spaces:
Sleeping
Sleeping
import os | |
import gradio as gr | |
from langchain_community.vectorstores import FAISS | |
from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings | |
from langchain_core.runnables.passthrough import RunnableAssign, RunnablePassthrough | |
from langchain.memory import ConversationBufferMemory | |
from langchain_core.messages import get_buffer_string | |
from langchain_nvidia_ai_endpoints import ChatNVIDIA, NVIDIAEmbeddings | |
from langchain_core.prompts import ChatPromptTemplate | |
from langchain_core.output_parsers import StrOutputParser | |
# NVIDIA API key read from the environment ("" when unset). It is not passed
# anywhere in this file; presumably the NVIDIA client classes read
# NVIDIA_API_KEY from the environment themselves — confirm before relying on it.
nvidia_api_key = os.getenv("NVIDIA_API_KEY", "")

# Embedding model handed to FAISS.load_local, so queries against `db` are
# embedded with it.
embedder = NVIDIAEmbeddings(model="nvolveqa_40k", model_type=None)

# Vector store loaded from the local "phuket_faiss" directory.
# NOTE(review): allow_dangerous_deserialization=True lets load_local unpickle
# the stored docstore. This is only acceptable while the index is a trusted
# artifact shipped with the app; never point it at user-supplied files.
db = FAISS.load_local(
    "phuket_faiss",
    embedder,
    allow_dangerous_deserialization=True,
)
from operator import itemgetter | |
# Chat model served through NVIDIA AI endpoints.
# Model names noted as available here: "mixtral_8x7b", "llama2_13b".
# Kept as the bare chat model: every chain that uses `llm` already ends in its
# own StrOutputParser(), so composing a parser here only parsed the text twice.
llm = ChatNVIDIA(model="mixtral_8x7b")
# Greeting seeded into the chat window as the bot's opening message.
# Plain literals: the second piece previously carried a needless f-prefix
# even though it has no placeholders.
initial_msg = (
    "Hello! I am Roam Mate to help you with your travel!"
    "\nHow can I help you?"
)
# RAG prompt. The system message carries the instructions, the retrieved
# documents ({context}) and the question; the user turn repeats {question}.
# The instruction now says "according to their preference"; it previously
# read "reference", which changed the meaning of the request.
# NOTE(review): [INST] ... [/INST] are raw Mixtral/Llama instruction markers.
# The chat endpoint presumably applies its own chat template, so they (and the
# duplicated question) may be redundant — confirm before removing.
prompt_template = ChatPromptTemplate.from_messages([
    ("system", """
### [INST] Instruction: Answer the question based on your knowledge about places in Thailand. You are Roam Mate which is a chat bot to help users with their travel and recommending places according to their preference. Here is context to help:
Document Retrieval:\n{context}\n
(Answer only from retrieval. Only cite sources that are used. Make your response conversational.)
### QUESTION:
{question} [/INST]
"""),
    ("user", "{question}"),
])
# Self-contained RAG pipeline. The input question is fed to a top-10
# similarity retriever to fill {context}, and forwarded unchanged as
# {question}. chat_gen does not use this chain: it retrieves with a score
# threshold and invokes conv_chain instead.
chain = (
    {
        "context": db.as_retriever(search_type="similarity", search_kwargs={"k": 10}),
        "question": RunnablePassthrough(),
    }
    | prompt_template
    | llm
    | StrOutputParser()
)
# Prompt -> chat model -> plain string. The caller must supply both
# "question" and "context" (chat_gen passes its pre-retrieved documents).
conv_chain = prompt_template | llm | StrOutputParser()
import logging

_logger = logging.getLogger(__name__)

# Reply used when no stored document clears the relevance threshold.
_NO_INFO_REPLY = (
    "I am sorry. I do not have relevant information to answer on that "
    "specific topic. Please try another question."
)


def _format_docs(docs):
    """Render retrieved documents as plain text for the prompt's {context} slot.

    Each document becomes its ``source`` metadata (``"unknown"`` when absent)
    followed by its page content. This lets the model cite sources instead of
    reading raw ``Document(...)`` reprs.
    """
    return "\n\n".join(
        f"Source: {doc.metadata.get('source', 'unknown')}\n{doc.page_content}"
        for doc in docs
    )


def chat_gen(message, history, return_buffer=True):
    """Answer a chat message with retrieval-augmented generation, streaming the reply.

    Used as the ``gr.ChatInterface`` callback.

    Args:
        message: The user's latest message; used both as the retrieval query
            and as the prompt's {question}.
        history: Previous turns supplied by Gradio; currently unused.
        return_buffer: If True (default), each yield is the whole reply so far.
            If False, each yield is only the newly generated piece. This flag
            was previously ignored on the answer path.

    Yields:
        str: Reply text, cumulative or incremental per ``return_buffer``.
    """
    doc_retriever = db.as_retriever(
        search_type="similarity_score_threshold",
        search_kwargs={"score_threshold": 0.2, "k": 5},
    )
    retrieved_docs = doc_retriever.invoke(message)
    _logger.debug("Retrieved %d document(s): %s", len(retrieved_docs), retrieved_docs)

    # Nothing relevant enough: refuse rather than let the model answer unaided.
    if not retrieved_docs:
        yield _NO_INFO_REPLY
        return

    state = {"question": message, "context": _format_docs(retrieved_docs)}
    buffer = ""
    # Stream chunks as the model produces them, instead of waiting for the
    # whole answer and replaying it one character at a time.
    for chunk in conv_chain.stream(state):
        buffer += chunk
        yield buffer if return_buffer else chunk
    _logger.debug("Generated answer: %s", buffer)
# Seed the chat window with one [user, bot] pair: no user text, bot greeting.
_opening_turns = [[None, initial_msg]]
chatbot = gr.Chatbot(value=_opening_turns)

# Wire chat_gen into a chat UI. queue() turns on Gradio's request queue,
# which generator callbacks like chat_gen stream through. Then start the server.
iface = gr.ChatInterface(fn=chat_gen, chatbot=chatbot).queue()
iface.launch()