Spaces:
Sleeping
Sleeping
Kaung Myat Htet
commited on
Commit
•
e249f66
1
Parent(s):
77e2b81
initialize project
Browse files- .gitignore +3 -0
- app.py +84 -0
- requirements.txt +6 -0
.gitignore
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
__pycache__
|
2 |
+
.DS_Store
|
3 |
+
faiss_index/*
|
app.py
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from langchain_community.vectorstores import FAISS
|
3 |
+
from langchain_mongodb.chat_message_histories import MongoDBChatMessageHistory
|
4 |
+
from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings, ChatNVIDIA
|
5 |
+
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
6 |
+
from langchain_core.messages import HumanMessage
|
7 |
+
from langchain_core.runnables.history import RunnableWithMessageHistory
|
8 |
+
import gradio as gr
|
9 |
+
import time
|
10 |
+
|
11 |
+
|
12 |
+
embedder = NVIDIAEmbeddings(model="NV-Embed-QA", model_type=None)
|
13 |
+
db = FAISS.load_local("faiss_index", embedder, allow_dangerous_deserialization=True)
|
14 |
+
model = ChatNVIDIA(model="meta/llama3-70b-instruct")
|
15 |
+
|
16 |
+
retriever = db.as_retriever(search_kwargs={"k": 8})
|
17 |
+
|
18 |
+
retrieved_docs = retriever.invoke("Seafood restaurants in Phuket")
|
19 |
+
print(len(retrieved_docs))
|
20 |
+
for doc in retrieved_docs:
|
21 |
+
print(doc.metadata)
|
22 |
+
|
23 |
+
|
24 |
+
|
25 |
+
def get_session_history(session_id):
|
26 |
+
return MongoDBChatMessageHistory(
|
27 |
+
session_id=session_id,
|
28 |
+
connection_string=os.environ["MONGODB_URI"],
|
29 |
+
database_name="tour_planner_db",
|
30 |
+
collection_name="chat_histories",
|
31 |
+
)
|
32 |
+
|
33 |
+
|
34 |
+
prompt = ChatPromptTemplate.from_messages(
|
35 |
+
[
|
36 |
+
("system", """
|
37 |
+
### [INST] Instruction: Answer the question based on your knowledge about places in Thailand. You are Roam Mate which is a chat bot to help users with their travel and recommending places according to their reference. Here is context to help:
|
38 |
+
Also provides your rationale for generating the places you are recommending.
|
39 |
+
Context:\n{context}\n
|
40 |
+
(Answer from retrieval if they are relevant to the question. Only cite sources that are used. Make your response conversational.)
|
41 |
+
|
42 |
+
|
43 |
+
### QUESTION:
|
44 |
+
{question} [/INST]
|
45 |
+
"""),
|
46 |
+
MessagesPlaceholder(variable_name="history")
|
47 |
+
]
|
48 |
+
)
|
49 |
+
|
50 |
+
runnable = prompt | model
|
51 |
+
|
52 |
+
runnable_with_history = RunnableWithMessageHistory(
|
53 |
+
runnable,
|
54 |
+
get_session_history,
|
55 |
+
input_messages_key="question",
|
56 |
+
history_messages_key="history",
|
57 |
+
)
|
58 |
+
|
59 |
+
initial_msg = (
|
60 |
+
"Hello! I am a chatbot to help with vacation."
|
61 |
+
f"\nHow can I help you?"
|
62 |
+
)
|
63 |
+
|
64 |
+
def chat_gen(message, history, session_id, return_buffer=True):
|
65 |
+
print(session_id)
|
66 |
+
buffer = ""
|
67 |
+
for token in runnable_with_history.stream(
|
68 |
+
{"question": message, "context": db.as_retriever(search_type="similarity", search_kwargs={"k": 5})},
|
69 |
+
config={"configurable": {"session_id": session_id}},
|
70 |
+
):
|
71 |
+
buffer += token.content
|
72 |
+
time.sleep(0.05)
|
73 |
+
yield buffer
|
74 |
+
|
75 |
+
with gr.Blocks(fill_height=True) as demo:
|
76 |
+
session_id = gr.Textbox("1", label="Session ID")
|
77 |
+
chatbot = gr.Chatbot(value = [[None, initial_msg]], bubble_full_width=True, scale=1)
|
78 |
+
gr.ChatInterface(chat_gen, chatbot=chatbot, additional_inputs=[session_id]).queue()
|
79 |
+
|
80 |
+
|
81 |
+
if __name__ == "__main__":
|
82 |
+
demo.launch()
|
83 |
+
|
84 |
+
|
requirements.txt
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
langchain
|
2 |
+
langchain-community
|
3 |
+
faiss-cpu
|
4 |
+
flashrank
|
5 |
+
langchain-nvidia-ai-endpoints
|
6 |
+
langchain-mongodb
|