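# Roam Mate: a Gradio RAG chatbot that recommends places in Thailand,
# answering from a local FAISS index of Phuket travel data via NVIDIA AI
# endpoint models.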
import os

import gradio as gr
from langchain_community.vectorstores import FAISS
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_nvidia_ai_endpoints import ChatNVIDIA, NVIDIAEmbeddings
# The embedding model must match the one used to build the "phuket_faiss" index.
embedder = NVIDIAEmbeddings(model="nvolveqa_40k", model_type=None)
db = FAISS.load_local("phuket_faiss", embedder, allow_dangerous_deserialization=True)
# docs = db.similarity_search(query)
# ChatNVIDIA and NVIDIAEmbeddings read NVIDIA_API_KEY from the environment.
nvidia_api_key = os.environ.get("NVIDIA_API_KEY", "")

# Available model names: mixtral_8x7b, llama2_13b
llm = ChatNVIDIA(model="mixtral_8x7b")
initial_msg = (
    "Hello! I am Roam Mate, here to help you with your travel!"
    "\nHow can I help you?"
)
prompt_template = ChatPromptTemplate.from_messages([
    ("system", """
### [INST] Instruction: Answer the question based on your knowledge about places in Thailand. You are Roam Mate, a chatbot that helps users with their travel and recommends places according to their preferences. Here is context to help:

Document Retrieval:\n{context}\n
(Answer only from the retrieval. Only cite sources that are used. Make your response conversational.)

### QUESTION:
{question} [/INST]
"""),
    ("user", "{question}"),
])
# Single-pipeline RAG chain: retrieve context for the question, fill the
# prompt, and call the LLM. (chat_gen below performs retrieval manually instead.)
chain = (
    {
        "context": db.as_retriever(search_type="similarity", search_kwargs={"k": 10}),
        "question": (lambda x: x),
    }
    | prompt_template
    | llm
    | StrOutputParser()
)
# Chain used by chat_gen: the caller supplies "question" and "context" directly.
conv_chain = (
    prompt_template
    | llm
    | StrOutputParser()
)
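# Streaming callback for gr.ChatInterface: retrieves documents above a
# similarity threshold, then yields the growing reply string so the UI can
# render it incrementally; answers with an apology when nothing is retrieved.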
def chat_gen(message, history, return_buffer=True):
    buffer = ""
    doc_retriever = db.as_retriever(
        search_type="similarity_score_threshold",
        search_kwargs={"score_threshold": 0.2, "k": 5},
    )
    retrieved_docs = doc_retriever.invoke(message)
    print(len(retrieved_docs))
    print(retrieved_docs)

    if len(retrieved_docs) > 0:
        state = {
            "question": message,
            "context": retrieved_docs,
        }
        ai_msg = conv_chain.invoke(state)
        print(ai_msg)
        for token in ai_msg:
            buffer += token
            yield buffer
        # buffer += "I used the following websites' data to generate the above answer:\n"
        # for doc in retrieved_docs:
        #     buffer += f"{doc.metadata['source']}\n"
    else:
        passage = (
            "I am sorry. I do not have relevant information to answer that "
            "specific question. Please try another one."
        )
        buffer += passage
        yield buffer if return_buffer else passage
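# Wire the generator into a Gradio chat UI; queue() enables streamed updates.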
chatbot = gr.Chatbot(value=[[None, initial_msg]])
iface = gr.ChatInterface(chat_gen, chatbot=chatbot).queue()
iface.launch()