RoamMate / app.py
teddyllm's picture
Update app.py
8f6f846 verified
raw
history blame
2.98 kB
import os
import gradio as gr
from langchain_community.vectorstores import FAISS
from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings
from langchain_core.runnables.passthrough import RunnableAssign, RunnablePassthrough
from langchain.memory import ConversationBufferMemory
from langchain_core.messages import get_buffer_string
from langchain_nvidia_ai_endpoints import ChatNVIDIA, NVIDIAEmbeddings
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
# Name of the NVIDIA embedding model handed to FAISS when the index is loaded.
EMBEDDING_MODEL = "nvolveqa_40k"
# Folder that contains the previously saved FAISS index.
FAISS_INDEX_DIR = "phuket_faiss"

embedder = NVIDIAEmbeddings(model=EMBEDDING_MODEL, model_type=None)

# NOTE(security): allow_dangerous_deserialization=True permits loading the
# pickle-backed index files; this is only safe while those files are trusted.
db = FAISS.load_local(
    FAISS_INDEX_DIR,
    embedder,
    allow_dangerous_deserialization=True,
)
# Example lookup against the index: db.similarity_search(query)

# Read from the environment (empty string when unset). Not referenced again
# in the visible part of this file.
nvidia_api_key = os.getenv("NVIDIA_API_KEY", "")
from operator import itemgetter
# Model names the original author listed as available:
#   - mixtral_8x7b
#   - llama2_13b
CHAT_MODEL_NAME = "mixtral_8x7b"

# Chat model piped into a string parser, so `llm` emits plain text.
_chat_model = ChatNVIDIA(model=CHAT_MODEL_NAME)
llm = _chat_model | StrOutputParser()
# Greeting used to seed the chat window as the first assistant message.
# Plain string literals only: there is nothing to interpolate, so no f-prefix.
initial_msg = (
    "Hello! I am Roam Mate to help you with your travel!"
    "\nHow can I help you?"
)
# System instructions for the assistant. Placeholders: {context} receives the
# retrieved documents and {question} receives the user's message.
# NOTE(review): {question} is also sent again as the user turn below, so the
# model sees it twice; "reference" may have been meant as "preference".
# Confirm with the author before changing the wording.
_SYSTEM_PROMPT = """
### [INST] Instruction: Answer the question based on your knowledge about places in Thailand. You are Roam Mate which is a chat bot to help users with their travel and recommending places according to their reference. Here is context to help:
Document Retrieval:\n{context}\n
(Answer only from retrieval. Only cite sources that are used. Make your response conversational.)
### QUESTION:
{question} [/INST]
"""

# Two-message chat prompt: the system instructions, then the user's question.
prompt_template = ChatPromptTemplate.from_messages(
    [
        ("system", _SYSTEM_PROMPT),
        ("user", "{question}"),
    ]
)
# Retrieval + generation pipeline: the input is sent both to a top-10
# similarity retriever (as 'context') and through unchanged (as 'question'),
# then the prompt is filled, the model is called and the text is parsed.
# The chat handler visible in this file uses `conv_chain`, not this chain.
_chain_retriever = db.as_retriever(
    search_type="similarity",
    search_kwargs={"k": 10},
)
_chain_inputs = {
    "context": _chain_retriever,
    "question": lambda x: x,  # forward the raw input unchanged
}
chain = _chain_inputs | prompt_template | llm | StrOutputParser()
# Generation-only pipeline: the caller supplies a mapping that already holds
# 'context' and 'question'; the prompt is filled, the model is called and the
# result is parsed to a string.
conv_chain = prompt_template | llm | StrOutputParser()
def chat_gen(message, history, return_buffer=True):
    """Stream Roam Mate's reply to ``message`` for the Gradio chat interface.

    Documents are retrieved from the FAISS index with a similarity-score
    threshold. If any are found, they are passed together with the question
    to ``conv_chain`` and the answer is streamed back chunk by chunk as the
    model produces it. If none are found, a fixed apology is yielded.

    Args:
        message: The user's latest chat message.
        history: Earlier chat turns passed in by Gradio; not used here.
        return_buffer: When true (the default), every yield is the whole
            reply accumulated so far. When false, every yield is only the
            newest chunk.

    Yields:
        str: The partial reply (or newest chunk), or the apology text.
    """
    doc_retriever = db.as_retriever(
        search_type="similarity_score_threshold",
        search_kwargs={"score_threshold": 0.2, "k": 5},
    )
    retrieved_docs = doc_retriever.invoke(message)
    # Debug output: how many documents matched, and what they were.
    print(len(retrieved_docs))
    print(retrieved_docs)

    # Guard clause: nothing relevant was retrieved, so do not call the model.
    if not retrieved_docs:
        yield (
            "I am sorry. I do not have relevant information to answer on "
            "that specific topic. Please try another question."
        )
        return

    # NOTE(review): the retrieved documents are handed to the prompt as-is,
    # so they are rendered through their string form; consider formatting
    # them explicitly if the prompt should contain only the passage text.
    state = {
        "question": message,
        "context": retrieved_docs,
    }

    # Stream from the chain instead of invoking it and then walking the
    # finished string character by character: the UI updates as soon as the
    # model emits text, and each iteration carries a real model chunk.
    buffer = ""
    for token in conv_chain.stream(state):
        buffer += token
        yield buffer if return_buffer else token
    # Debug output: the complete answer once streaming has finished.
    print(buffer)
# Seed the chat window with the greeting as an assistant-only first turn.
chatbot = gr.Chatbot(value=[[None, initial_msg]])

# Wire the streaming handler into a chat UI, enable queuing, and start it.
iface = gr.ChatInterface(fn=chat_gen, chatbot=chatbot)
iface = iface.queue()
iface.launch()