File size: 2,979 Bytes
c55b65b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c783ba9
c55b65b
 
 
 
 
 
 
 
 
efb12ec
8b62fd7
8f6f846
c55b65b
06369a1
 
 
 
 
c55b65b
c055062
c55b65b
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import logging
import os

import gradio as gr
from langchain.memory import ConversationBufferMemory
from langchain_community.vectorstores import FAISS
from langchain_core.messages import get_buffer_string
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables.passthrough import RunnableAssign, RunnablePassthrough
from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings
from langchain_nvidia_ai_endpoints import ChatNVIDIA, NVIDIAEmbeddings


# Embedding model handed to FAISS so queries are embedded the same way.
embedder = NVIDIAEmbeddings(
    model="nvolveqa_40k",
    model_type=None,
)

# Load the prebuilt local "phuket_faiss" index.
# allow_dangerous_deserialization=True opts in to pickle-based loading of the
# stored index data; only point this at an index directory you created.
db = FAISS.load_local(
    "phuket_faiss",
    embedder,
    allow_dangerous_deserialization=True,
)

# Example direct query against the index: docs = db.similarity_search(query)

# NOTE(review): read from the environment but not referenced elsewhere here.
nvidia_api_key = os.getenv("NVIDIA_API_KEY", "")


from operator import itemgetter


# available models names
# mixtral_8x7b
# llama2_13b
llm = ChatNVIDIA(model="mixtral_8x7b") | StrOutputParser()

# Greeting shown as the bot's first message in the chat window.
# (Plain string literals: the second piece had a needless f-prefix with no
# placeholders.)
initial_msg = (
    "Hello! I am Roam Mate to help you with your travel!"
    "\nHow can I help you?"
)

# Two-message chat prompt filled with the variables `context` (retrieved
# documents) and `question` (the user's message).
# Fixed typo in the instruction: "according to their reference" -> "preference".
# NOTE(review): `{question}` is interpolated into both the system message and
# the user message, so the model sees the question twice — confirm intended.
# NOTE(review): the Mixtral-style [INST] ... [/INST] markers are embedded in a
# system message; the chat endpoint presumably applies its own template —
# confirm these markers are still wanted.
prompt_template = ChatPromptTemplate.from_messages([("system", """
### [INST] Instruction: Answer the question based on your knowledge about places in Thailand. You are Roam Mate which is a chat bot to help users with their travel and recommending places according to their preference. Here is context to help:
Document Retrieval:\n{context}\n
(Answer only from retrieval. Only cite sources that are used. Make your response conversational.)


### QUESTION:
{question} [/INST]
 """), ('user', '{question}')])

# Stand-alone RAG pipeline: for an input string, retrieve the 10 most similar
# chunks from `db` as `context`, pass the input through unchanged as
# `question`, fill the prompt, and run the model.
# `llm` already ends in a StrOutputParser, so the result is a `str` without a
# second parser.
# NOTE(review): chat_gen does not use this chain; it does its own thresholded
# retrieval and calls conv_chain instead.
chain = (
    {
        "context": db.as_retriever(search_type="similarity", search_kwargs={"k": 10}),
        "question": RunnablePassthrough(),
    }
    | prompt_template
    | llm
)

# Prompt -> model pipeline with no retrieval step: the caller supplies both
# "question" and "context" (chat_gen passes its own retrieved documents).
# `llm` already ends in a StrOutputParser, so this yields `str` output
# (including when streamed) without a redundant second parser.
conv_chain = prompt_template | llm

def chat_gen(message, history, return_buffer=True):
    """Stream a retrieval-grounded reply to ``message`` for ``gr.ChatInterface``.

    Documents are fetched from ``db`` with the "similarity_score_threshold"
    search (at most 5, score threshold 0.2). If none qualify, a fixed apology
    is yielded and the model is not called. Otherwise the question and the
    retrieved documents are sent through ``conv_chain``, and the model output
    is streamed back as it is produced.

    Args:
        message: The user's latest chat message.
        history: Earlier turns passed by ``gr.ChatInterface``; not used here.
        return_buffer: When True (the default, which is what ChatInterface
            needs), every yield is the full reply accumulated so far. When
            False, every yield is only the newly produced piece of text.

    Yields:
        str: The growing reply, or successive pieces of it (see
        ``return_buffer``).
    """
    logger = logging.getLogger(__name__)
    buffer = ""

    doc_retriever = db.as_retriever(
        search_type="similarity_score_threshold",
        search_kwargs={"score_threshold": 0.2, "k": 5},
    )
    retrieved_docs = doc_retriever.invoke(message)
    # Debug-level logging instead of printing every document on each request.
    logger.debug("Retrieved %d document(s): %s", len(retrieved_docs), retrieved_docs)

    if not retrieved_docs:
        passage = "I am sorry. I do not have relevant information to answer on that specific topic. Please try another question."
        buffer += passage
        yield buffer if return_buffer else passage
        return

    state = {
        'question': message,
        'context': retrieved_docs,
    }
    # Stream chunks from the chain rather than invoking it and then iterating
    # the finished string character by character, so the UI updates as the
    # model produces text.
    for chunk in conv_chain.stream(state):
        buffer += chunk
        # Honor return_buffer here too (previously only the no-docs branch did).
        yield buffer if return_buffer else chunk
    logger.debug("Generated reply: %s", buffer)


# Chat window pre-seeded with the bot's greeting as its first message.
chatbot = gr.Chatbot(value=[[None, initial_msg]])
# Wire the streaming generator into a chat UI, enable the request queue,
# then start the server.
iface = gr.ChatInterface(chat_gen, chatbot=chatbot)
iface = iface.queue()
iface.launch()