Spaces:
Runtime error
Runtime error
Maksimov-Dmitry
commited on
Commit
·
d1a829e
1
Parent(s):
eb025bc
app
Browse files- .gitattributes +1 -0
- app.py +250 -0
- data/db/.lock +1 -0
- data/db/collection/hotels/storage.sqlite +3 -0
- data/db/meta.json +1 -0
- requirements.txt +6 -0
- src/__pycache__/prompts.cpython-310.pyc +0 -0
- src/__pycache__/prompts.cpython-39.pyc +0 -0
- src/__pycache__/rag.cpython-310.pyc +0 -0
- src/__pycache__/rag.cpython-39.pyc +0 -0
- src/__pycache__/retriever.cpython-310.pyc +0 -0
- src/__pycache__/streamlit_utils.cpython-310.pyc +0 -0
- src/__pycache__/streamlit_utils.cpython-39.pyc +0 -0
- src/create_vector_db.py +279 -0
- src/prompts.py +62 -0
- src/retriever.py +92 -0
- src/streamlit_utils.py +80 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
*.sqlite filter=lfs diff=lfs merge=lfs -text
|
app.py
ADDED
@@ -0,0 +1,250 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from src import streamlit_utils
|
2 |
+
from src.prompts import AGENT_SYSTEM_PROMPT, AGENT_USER_PROMPT, RAG_USER_PROMPT, TRAVERSIALAI_USER_PROMPT
|
3 |
+
from src.retriever import Retriever
|
4 |
+
|
5 |
+
import streamlit as st
|
6 |
+
|
7 |
+
from langchain_openai import ChatOpenAI
|
8 |
+
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
9 |
+
from langchain.memory import ChatMessageHistory
|
10 |
+
import re
|
11 |
+
import requests
|
12 |
+
import os
|
13 |
+
from qdrant_client import QdrantClient
|
14 |
+
|
15 |
+
collection_name = 'hotels'
|
16 |
+
|
17 |
+
st.set_page_config(page_title="Hotels search chatbot", page_icon="⭐")
|
18 |
+
st.header('Hotels search chatbot')
|
19 |
+
st.write('[![view source code and description](https://img.shields.io/badge/view_source_code-gray?logo=github)](https://github.com/Maksimov-Dmitry/traversaal-ai-hackathon)')
|
20 |
+
st.write('Developed by [Dmitry Maksimov](https://www.linkedin.com/in/maksimov-dmitry/), maksimov.dmitry.m@gmail.com and [Ilya Dudnik](https://www.linkedin.com/in/ilia-dudnik-5b8018271/), ilia.dudnik@fau.de')
|
21 |
+
|
22 |
+
st.sidebar.header('Choose your preferences')
|
23 |
+
n_hotels = st.sidebar.number_input('Number of hotels', min_value=1, max_value=10, value=3)
|
24 |
+
|
25 |
+
|
26 |
+
@st.cache_resource
|
27 |
+
def get_db_client(path='data/db'):
|
28 |
+
client = QdrantClient(path=path)
|
29 |
+
return client
|
30 |
+
|
31 |
+
|
32 |
+
def add_new_info(chat_history, queries):
|
33 |
+
"""After the user has changed any parameters (city, price, rating), we notify the Agent about it.
|
34 |
+
The information is added to the chat history.
|
35 |
+
Args:
|
36 |
+
chat_history: history of the chat
|
37 |
+
queries (list): list of queries that the user has changed
|
38 |
+
"""
|
39 |
+
for query in queries:
|
40 |
+
chat_history.add_user_message(query)
|
41 |
+
chat_history.add_ai_message('Ok, got it!')
|
42 |
+
|
43 |
+
|
44 |
+
def check_params(params):
|
45 |
+
"""Check if the user has changed the parameters (city, price, rating).
|
46 |
+
If the user has changed the parameters, the corresponding queries are created.
|
47 |
+
|
48 |
+
Args:
|
49 |
+
params (dict): dictionary with the parameters
|
50 |
+
|
51 |
+
Returns:
|
52 |
+
list: list of queries that the user has changed
|
53 |
+
"""
|
54 |
+
changed_params = []
|
55 |
+
|
56 |
+
if 'prev_params' not in st.session_state:
|
57 |
+
st.session_state.prev_params = {'city': '<BLANK>', 'price': '<BLANK>', 'rating': '<BLANK>'}
|
58 |
+
|
59 |
+
if st.session_state.prev_params['city'] != params['city']:
|
60 |
+
changed_params.append(f'I want to find hotels in {params["city"]}' if params['city'] else 'I want to find hotels in any city')
|
61 |
+
|
62 |
+
if st.session_state.prev_params['price'] != params['price']:
|
63 |
+
changed_params.append(f'I want to find hotels in price range {params["price"]}' if params['price'] else 'I want to find hotels in any price range')
|
64 |
+
|
65 |
+
if st.session_state.prev_params['rating'] != params['rating']:
|
66 |
+
changed_params.append(f'I want to find hotels with rating greater than {params["rating"]}')
|
67 |
+
|
68 |
+
st.session_state.prev_params = params
|
69 |
+
|
70 |
+
return changed_params
|
71 |
+
|
72 |
+
|
73 |
+
def get_parameters(db_client):
|
74 |
+
"""Get the parameters from the user (city, price, rating),
|
75 |
+
The provided metadata (in case it was provided by the user) is used in the MixedRetrieval from Qdrant vector DB
|
76 |
+
"""
|
77 |
+
points, _ = db_client.scroll(
|
78 |
+
collection_name=collection_name,
|
79 |
+
limit=1e9,
|
80 |
+
with_payload=True,
|
81 |
+
with_vectors=False,
|
82 |
+
)
|
83 |
+
cities = ['Doest not matter'] + list(set([point.payload['city'] for point in points]))
|
84 |
+
city = st.sidebar.selectbox('City', list(cities), index=0)
|
85 |
+
if city == 'Doest not matter':
|
86 |
+
city = None
|
87 |
+
|
88 |
+
prices = ['Doest not matter'] + list(set([point.payload['price'] for point in points]))
|
89 |
+
price = st.sidebar.selectbox('Price', list(prices), index=0)
|
90 |
+
if price == 'Doest not matter':
|
91 |
+
price = None
|
92 |
+
|
93 |
+
rating = st.sidebar.slider('Min hotel rating', min_value=.0, max_value=5.0, value=4.5, step=.5)
|
94 |
+
return dict(city=city, price=price, rating=rating)
|
95 |
+
|
96 |
+
|
97 |
+
class HotelsSearchChatbot:
|
98 |
+
"""
|
99 |
+
This is the Agent class. It is responsible for the decision-making during conversation with the user.
|
100 |
+
Based on the user's query, the Agent decides which action to take and how to present result to the user.
|
101 |
+
"""
|
102 |
+
def __init__(self, db_client):
|
103 |
+
streamlit_utils.configure_api_keys()
|
104 |
+
|
105 |
+
self.llm_model = "gpt-4-1106-preview"
|
106 |
+
self.temperature = 0.6
|
107 |
+
|
108 |
+
self.embeedings_model = "text-embedding-3-large"
|
109 |
+
self.rerank_model = 'rerank-multilingual-v2.0'
|
110 |
+
|
111 |
+
self.ares_api_key = os.environ.get("ARES_API_KEY")
|
112 |
+
self.db_client = db_client
|
113 |
+
|
114 |
+
def _traversialai(self, query):
|
115 |
+
"""Acquiring information from the internet using the Traversaal.ai.
|
116 |
+
|
117 |
+
Args:
|
118 |
+
query (str): search query
|
119 |
+
|
120 |
+
Returns:
|
121 |
+
str: information from the internet based on the query
|
122 |
+
"""
|
123 |
+
url = "https://api-ares.traversaal.ai/live/predict"
|
124 |
+
|
125 |
+
payload = {"query": [query]}
|
126 |
+
headers = {
|
127 |
+
"x-api-key": self.ares_api_key,
|
128 |
+
"content-type": "application/json"
|
129 |
+
}
|
130 |
+
|
131 |
+
response = requests.post(url, json=payload, headers=headers)
|
132 |
+
try:
|
133 |
+
return response.json()['data']['response_text']
|
134 |
+
except:
|
135 |
+
return None
|
136 |
+
|
137 |
+
def _get_action(self, text):
|
138 |
+
"""Parse (read) the action and the action input from the response of the Agent
|
139 |
+
(after he made a decision what to do).
|
140 |
+
'action' and 'action_input' indicate whether we need to query additional tools
|
141 |
+
(vector DB, Traversaal AI) and how.
|
142 |
+
|
143 |
+
Args:
|
144 |
+
text (str): response of the Agent, which contains the action and the action input
|
145 |
+
|
146 |
+
Returns:
|
147 |
+
tuple: action, action input
|
148 |
+
"""
|
149 |
+
action_pattern = r"Action:\s*(.*)\n"
|
150 |
+
action_input_pattern = r"Action Input:\s*(.*)"
|
151 |
+
|
152 |
+
action_match = re.search(action_pattern, text)
|
153 |
+
action_input_match = re.search(action_input_pattern, text)
|
154 |
+
|
155 |
+
action = action_match.group(1) if action_match else None
|
156 |
+
action_input = action_input_match.group(1) if action_input_match else None
|
157 |
+
return action, action_input
|
158 |
+
|
159 |
+
def _make_action(self, action, action_input, retriever, chain, chat_history, config, retriever_params):
|
160 |
+
"""Take the action corresponding to 'action' and 'action input'. The 'action' can be one of the following:
|
161 |
+
'nothing' - Agent is capable of dealing on its own without use of additional tools,
|
162 |
+
'hotels_data_base' - Agent decides to get the information from the hotels vector DB,
|
163 |
+
'ares_api' - Agent requires additional information from the internet using the Traversaal.ai.
|
164 |
+
|
165 |
+
Args:
|
166 |
+
action (str): action to make
|
167 |
+
action_input (str): action input (formulated by Agent search query)
|
168 |
+
retriever (Retriever): Retriever object
|
169 |
+
chain (Chain): Chain object
|
170 |
+
chat_history (ChatMessageHistory): history of the chat
|
171 |
+
config (dict): handlers for a LangChain invoke method
|
172 |
+
retriever_params (dict): parameters for the Retriever
|
173 |
+
"""
|
174 |
+
if action == 'nothing':
|
175 |
+
st.markdown(action_input)
|
176 |
+
return action_input
|
177 |
+
|
178 |
+
if action == 'hotels_data_base':
|
179 |
+
context = retriever(action_input, top_k=n_hotels, **retriever_params)
|
180 |
+
chat_history.add_user_message(RAG_USER_PROMPT.format(context=context, query=action_input))
|
181 |
+
response = chain.invoke({"messages": chat_history.messages}, config)
|
182 |
+
chat_history.messages.pop()
|
183 |
+
return response.content
|
184 |
+
|
185 |
+
if action == 'ares_api':
|
186 |
+
context = self._traversialai(action_input)
|
187 |
+
chat_history.add_user_message(TRAVERSIALAI_USER_PROMPT.format(context=context, query=action_input))
|
188 |
+
response = chain.invoke({"messages": chat_history.messages}, config)
|
189 |
+
chat_history.messages.pop()
|
190 |
+
return response.content
|
191 |
+
|
192 |
+
return None
|
193 |
+
|
194 |
+
@st.cache_resource
|
195 |
+
def setup_chain(_self):
|
196 |
+
retriever = Retriever(embedding_model=_self.embeedings_model, llm_model=_self.llm_model,
|
197 |
+
rerank_model=_self.rerank_model, db_client=_self.db_client, db_collection=collection_name)
|
198 |
+
|
199 |
+
chat_history = ChatMessageHistory()
|
200 |
+
prompt = ChatPromptTemplate.from_messages(
|
201 |
+
[
|
202 |
+
(
|
203 |
+
"system",
|
204 |
+
AGENT_SYSTEM_PROMPT,
|
205 |
+
),
|
206 |
+
MessagesPlaceholder(variable_name="messages"),
|
207 |
+
]
|
208 |
+
)
|
209 |
+
chat = ChatOpenAI(model=_self.llm_model, temperature=_self.temperature, streaming=True)
|
210 |
+
chain = prompt | chat
|
211 |
+
|
212 |
+
return chain, chat_history, retriever
|
213 |
+
|
214 |
+
@streamlit_utils.enable_chat_history
|
215 |
+
def main(self, params):
|
216 |
+
chain, chat_history, retriever = self.setup_chain()
|
217 |
+
user_query = st.chat_input(placeholder="Ask me anything!")
|
218 |
+
if user_query:
|
219 |
+
streamlit_utils.display_msg(user_query, 'user')
|
220 |
+
|
221 |
+
# add new info to the chat history
|
222 |
+
queries = check_params(params)
|
223 |
+
add_new_info(chat_history, queries)
|
224 |
+
|
225 |
+
# get the action and the action input based on the user's query
|
226 |
+
chat_history.add_user_message(AGENT_USER_PROMPT.format(input=user_query))
|
227 |
+
action_response = chain.invoke({"messages": chat_history.messages})
|
228 |
+
chat_history.messages.pop()
|
229 |
+
action, action_input = self._get_action(action_response.content)
|
230 |
+
|
231 |
+
with st.chat_message("assistant"):
|
232 |
+
st_cb = streamlit_utils.StreamHandler(st.empty())
|
233 |
+
|
234 |
+
# create response on the user's query
|
235 |
+
response = self._make_action(action, action_input,
|
236 |
+
retriever, chain, chat_history, {"callbacks": [st_cb]}, params)
|
237 |
+
chat_history.add_user_message(user_query)
|
238 |
+
if response is None:
|
239 |
+
response = 'Sorry, I cannot help you with it. Could you rephrase your question?'
|
240 |
+
st.markdown(response)
|
241 |
+
|
242 |
+
chat_history.add_ai_message(response)
|
243 |
+
st.session_state.messages.append({"role": "assistant", "content": response})
|
244 |
+
|
245 |
+
|
246 |
+
if __name__ == "__main__":
|
247 |
+
db_client = get_db_client()
|
248 |
+
params = get_parameters(db_client)
|
249 |
+
obj = HotelsSearchChatbot(db_client)
|
250 |
+
obj.main(params)
|
data/db/.lock
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
tmp lock file
|
data/db/collection/hotels/storage.sqlite
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:deb2004afca01078aacc8036779b783e47f7f9c52d440a517b32eb81b892af97
|
3 |
+
size 4726784
|
data/db/meta.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"collections": {"hotels": {"vectors": {"size": 3072, "distance": "Cosine", "hnsw_config": null, "quantization_config": null, "on_disk": null}, "shard_number": null, "sharding_method": null, "replication_factor": null, "write_consistency_factor": null, "on_disk_payload": null, "hnsw_config": null, "wal_config": null, "optimizers_config": null, "init_from": null, "quantization_config": null, "sparse_vectors": null}}, "aliases": {}}
|
requirements.txt
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
langchain
|
2 |
+
langchain-community
|
3 |
+
langchain-openai
|
4 |
+
qdrant-client
|
5 |
+
openai
|
6 |
+
cohere
|
src/__pycache__/prompts.cpython-310.pyc
ADDED
Binary file (3.57 kB). View file
|
|
src/__pycache__/prompts.cpython-39.pyc
ADDED
Binary file (3.36 kB). View file
|
|
src/__pycache__/rag.cpython-310.pyc
ADDED
Binary file (2.6 kB). View file
|
|
src/__pycache__/rag.cpython-39.pyc
ADDED
Binary file (3.07 kB). View file
|
|
src/__pycache__/retriever.cpython-310.pyc
ADDED
Binary file (3.93 kB). View file
|
|
src/__pycache__/streamlit_utils.cpython-310.pyc
ADDED
Binary file (2.54 kB). View file
|
|
src/__pycache__/streamlit_utils.cpython-39.pyc
ADDED
Binary file (1.69 kB). View file
|
|
src/create_vector_db.py
ADDED
@@ -0,0 +1,279 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import click
|
2 |
+
from qdrant_client import QdrantClient, models
|
3 |
+
from openai import OpenAI
|
4 |
+
from tqdm import tqdm
|
5 |
+
import json
|
6 |
+
import requests
|
7 |
+
import os
|
8 |
+
from prompts import REVIEWS_SYSTEM_PROMPT, REVIEWS_USER_PROMPT
|
9 |
+
|
10 |
+
TRIPADVISOR_API_KEY = os.environ.get('TRIPADVISOR_API_KEY')
|
11 |
+
|
12 |
+
|
13 |
+
def save_json(data, path):
|
14 |
+
with open(path, "w") as outfile:
|
15 |
+
json.dump(data, outfile)
|
16 |
+
|
17 |
+
|
18 |
+
def get_df(dataset_path, is_hf):
|
19 |
+
if is_hf:
|
20 |
+
from datasets import load_dataset
|
21 |
+
dataset = load_dataset(dataset_path)
|
22 |
+
return dataset['train'].to_pandas()
|
23 |
+
else:
|
24 |
+
import pandas as pd
|
25 |
+
return pd.read_csv(dataset_path)
|
26 |
+
|
27 |
+
|
28 |
+
def _concat_reviews(df):
|
29 |
+
text = ''
|
30 |
+
for _, row in df.iterrows():
|
31 |
+
text += '\n'
|
32 |
+
if row.review_title:
|
33 |
+
text += '\nTitle:\n' + row.review_title
|
34 |
+
if row.review_text:
|
35 |
+
text += '\nReview:\n' + row.review_text
|
36 |
+
|
37 |
+
return text
|
38 |
+
|
39 |
+
|
40 |
+
def create_reviews_symmary(df, model, hotels, pos_rate=4.0, neg_rate=4.0, n_reviews=6):
|
41 |
+
"""Create a summary of reviews for each hotel, based on the most positive and most negative reviews.
|
42 |
+
|
43 |
+
Args:
|
44 |
+
df (pd.DataFrame): hotels dataset
|
45 |
+
model (str): OpenAI model name
|
46 |
+
hotels (list): list of hotels to create summaries for
|
47 |
+
pos_rate (float): minimum positive rate, inclusive
|
48 |
+
neg_rate (float): maximum negative rate, exclusive
|
49 |
+
n_reviews (int): number of reviews to consider for each category
|
50 |
+
|
51 |
+
Returns:
|
52 |
+
dict: hotel name -> reviews summary
|
53 |
+
"""
|
54 |
+
df['review_text_len'] = df.review_text.str.len().fillna(value=0)
|
55 |
+
df['review_title_len'] = df.review_title.str.len().fillna(value=0)
|
56 |
+
|
57 |
+
client = OpenAI()
|
58 |
+
hotels_reviews_summary = {}
|
59 |
+
for hotel in tqdm(hotels):
|
60 |
+
temp = df[df.hotel_name.eq(hotel)]
|
61 |
+
temp_pos = temp[temp.rate >= pos_rate].nlargest(n_reviews, 'review_text_len')
|
62 |
+
temp_neg = temp[temp.rate < neg_rate].nlargest(n_reviews, 'review_text_len')
|
63 |
+
if len(temp_pos) == 0 and len(temp_neg) == 0:
|
64 |
+
temp_pos = temp.nlargest(n_reviews, 'review_title_len')
|
65 |
+
|
66 |
+
text = _concat_reviews(temp_pos) + _concat_reviews(temp_neg)
|
67 |
+
|
68 |
+
if text:
|
69 |
+
response = client.chat.completions.create(
|
70 |
+
model=model,
|
71 |
+
messages=[
|
72 |
+
{"role": "system", "content": REVIEWS_SYSTEM_PROMPT},
|
73 |
+
{"role": "user", "content": REVIEWS_USER_PROMPT.format(text=text)},
|
74 |
+
]
|
75 |
+
)
|
76 |
+
hotels_reviews_summary[hotel] = response.choices[0].message.content
|
77 |
+
return hotels_reviews_summary
|
78 |
+
|
79 |
+
|
80 |
+
def _get_loc_id(hotel):
|
81 |
+
""" Given a hotel name, receive location id.
|
82 |
+
In order to get the hotel info, we need to get the location id first.
|
83 |
+
|
84 |
+
Args:
|
85 |
+
hotel (str): hotel name
|
86 |
+
|
87 |
+
Returns:
|
88 |
+
str: location id
|
89 |
+
"""
|
90 |
+
url = "https://api.content.tripadvisor.com/api/v1/location/search?key={key}&searchQuery={hotel}&category=hotels&language=en"
|
91 |
+
headers = {"accept": "application/json"}
|
92 |
+
|
93 |
+
response = requests.get(url.format(hotel=hotel, key=TRIPADVISOR_API_KEY), headers=headers)
|
94 |
+
try:
|
95 |
+
return response.json()['data'][0]['location_id']
|
96 |
+
except Exception as e:
|
97 |
+
print(f'{response.status_code=}')
|
98 |
+
print(f'{response.text=}')
|
99 |
+
print(f'Error: {e}')
|
100 |
+
return None
|
101 |
+
|
102 |
+
|
103 |
+
def get_hotel_info(hotel):
|
104 |
+
"""Get hotel info from TripAdvisor.
|
105 |
+
The following information is retrieved using the TripAdvisor API:
|
106 |
+
- rank
|
107 |
+
- ratings distributions
|
108 |
+
- subratings
|
109 |
+
- amenities
|
110 |
+
|
111 |
+
Args:
|
112 |
+
hotel (str): hotel name
|
113 |
+
|
114 |
+
Returns:
|
115 |
+
dict: hotel info
|
116 |
+
"""
|
117 |
+
url = "https://api.content.tripadvisor.com/api/v1/location/{loc_id}/details?key={key}&language=en¤cy=USD"
|
118 |
+
headers = {"accept": "application/json"}
|
119 |
+
|
120 |
+
loc_id = _get_loc_id(hotel)
|
121 |
+
if loc_id is None:
|
122 |
+
return None
|
123 |
+
response = requests.get(url.format(loc_id=loc_id, key=TRIPADVISOR_API_KEY), headers=headers)
|
124 |
+
try:
|
125 |
+
response = response.json()
|
126 |
+
except Exception as e:
|
127 |
+
print(f'{response.status_code=}')
|
128 |
+
print(f'{response.text=}')
|
129 |
+
print(f'Error: {e}')
|
130 |
+
return None
|
131 |
+
rank = response['ranking_data'].get('ranking_string')
|
132 |
+
reviews_ratings = response.get('review_rating_count')
|
133 |
+
subratings = {}
|
134 |
+
for d in response['subratings']:
|
135 |
+
subratings[response['subratings'][d]['name']] = response['subratings'][d]['value']
|
136 |
+
amenities = response.get('amenities', [])
|
137 |
+
return dict(
|
138 |
+
rank=rank,
|
139 |
+
reviews_ratings=reviews_ratings,
|
140 |
+
subratings=subratings,
|
141 |
+
amenities=amenities,
|
142 |
+
)
|
143 |
+
|
144 |
+
|
145 |
+
def get_desc(hotel, data):
|
146 |
+
"""Create a text description of the hotel based on the retrieved data from TripAdvisor.
|
147 |
+
|
148 |
+
Args:
|
149 |
+
hotel (str): hotel name
|
150 |
+
data (dict): hotel info
|
151 |
+
|
152 |
+
Returns:
|
153 |
+
str: hotel text description
|
154 |
+
"""
|
155 |
+
rating = "Rating: "+str(data[hotel]['rank'])+". "
|
156 |
+
|
157 |
+
distr_ranks = "Rating distribution "
|
158 |
+
for key in data[hotel]['reviews_ratings'].keys():
|
159 |
+
distr_ranks += str(key) + ": " + str(data[hotel]['reviews_ratings'][key] + ", ")
|
160 |
+
distr_ranks = distr_ranks[:-2]+". "
|
161 |
+
|
162 |
+
sub_ranks = "Specific ratings: "
|
163 |
+
if 'rate_location' in data[hotel]['subratings'].keys():
|
164 |
+
sub_ranks += "Location " + data[hotel]['subratings']['rate_location'] + ", "
|
165 |
+
|
166 |
+
if 'rate_sleep' in data[hotel]['subratings'].keys():
|
167 |
+
sub_ranks += "Sleep " + data[hotel]['subratings']['rate_sleep'] + ", "
|
168 |
+
if 'rate_room' in data[hotel]['subratings'].keys():
|
169 |
+
sub_ranks += "Room " + data[hotel]['subratings']['rate_room'] + ", "
|
170 |
+
if 'rate_service' in data[hotel]['subratings'].keys():
|
171 |
+
sub_ranks += "Service " + data[hotel]['subratings']['rate_service'] + ", "
|
172 |
+
if 'rate_cleanliness' in data[hotel]['subratings'].keys():
|
173 |
+
sub_ranks += "Cleanliness " + data[hotel]['subratings']['rate_cleanliness']
|
174 |
+
sub_ranks += ". "
|
175 |
+
|
176 |
+
amenities = "Amenities available: "
|
177 |
+
for i in data[hotel]['amenities']:
|
178 |
+
amenities += str(i) + ", "
|
179 |
+
amenities = amenities[:-2] + "."
|
180 |
+
|
181 |
+
total_desc = rating + distr_ranks + sub_ranks + amenities
|
182 |
+
return total_desc
|
183 |
+
|
184 |
+
|
185 |
+
def get_payload(hotel, df):
|
186 |
+
"""Create a metadata which will be stored in the database.
|
187 |
+
|
188 |
+
Args:
|
189 |
+
hotel (str): hotel name
|
190 |
+
df (pd.DataFrame): hotels dataset
|
191 |
+
|
192 |
+
Returns:
|
193 |
+
dict: metadata
|
194 |
+
"""
|
195 |
+
temp = df[df.hotel_name.eq(hotel)]
|
196 |
+
rating = temp.rating_value.value_counts().index[0]
|
197 |
+
city = temp.locality.value_counts().index[0]
|
198 |
+
country = temp.country.value_counts().index[0]
|
199 |
+
price = temp.price_range.str.split(' ').str[0].value_counts().index[0]
|
200 |
+
return dict(
|
201 |
+
hotel_name=hotel,
|
202 |
+
rating=rating,
|
203 |
+
city=city,
|
204 |
+
country=country,
|
205 |
+
price=price
|
206 |
+
)
|
207 |
+
|
208 |
+
|
209 |
+
@click.command()
|
210 |
+
@click.option('--dataset-path', default='traversaal-ai-hackathon/hotel_datasets', help='Path to the dataset.')
|
211 |
+
@click.option('--is-hf', is_flag=True, default=True, help='Whether the dataset is in huggingface format, csv otherwise.')
|
212 |
+
@click.option('--db-path', default='data/db', help='Path to the output database.')
|
213 |
+
@click.option('--collection-name', default='hotels', help='Name of the collection in the database.')
|
214 |
+
@click.option('--embeddings-model', default='text-embedding-3-large', help='Name of the model to use for embeddings.')
|
215 |
+
@click.option('--embeddings-size', default=3072, help='Size of the embeddings.')
|
216 |
+
@click.option('--reviews-model', default='gpt-3.5-turbo-0125', help='Name of the model to use for reviews summary.')
|
217 |
+
def create_vector_db(dataset_path, is_hf, db_path, collection_name, embeddings_model, embeddings_size, reviews_model):
|
218 |
+
REVIEW_SUMMARIES_PATH = 'reviews_summary.json'
|
219 |
+
HOTELS_INFO_PATH = 'hotels_info.json'
|
220 |
+
|
221 |
+
df = get_df(dataset_path, is_hf)
|
222 |
+
|
223 |
+
# Create a collection if it does not exist and filter out hotels that are already in the collection
|
224 |
+
qdrant_client = QdrantClient(path=db_path)
|
225 |
+
if not qdrant_client.collection_exists(collection_name):
|
226 |
+
qdrant_client.create_collection(
|
227 |
+
collection_name=collection_name,
|
228 |
+
vectors_config=models.VectorParams(size=embeddings_size, distance=models.Distance.COSINE),
|
229 |
+
)
|
230 |
+
hotels = df.hotel_name.unique()
|
231 |
+
else:
|
232 |
+
docs, _ = qdrant_client.scroll(
|
233 |
+
collection_name=collection_name,
|
234 |
+
limit=1e9,
|
235 |
+
with_payload=True,
|
236 |
+
with_vectors=False,
|
237 |
+
)
|
238 |
+
hotels = set(df.hotel_name.unique()) - set([doc.payload['hotel_name'] for doc in docs])
|
239 |
+
if len(hotels) == 0:
|
240 |
+
return
|
241 |
+
|
242 |
+
# Create reviews summary using OpenAI
|
243 |
+
reviews_summary = create_reviews_symmary(df, reviews_model, hotels)
|
244 |
+
save_json(reviews_summary, REVIEW_SUMMARIES_PATH)
|
245 |
+
|
246 |
+
# Get hotel info from TripAdvisor
|
247 |
+
hotels_info = {}
|
248 |
+
for hotel in tqdm(hotels):
|
249 |
+
hotels_info[hotel] = get_hotel_info(hotel)
|
250 |
+
save_json(hotels_info, HOTELS_INFO_PATH)
|
251 |
+
|
252 |
+
# Create descriptions and payloads for each hotel
|
253 |
+
texts = []
|
254 |
+
payloads = []
|
255 |
+
for hotel in hotels:
|
256 |
+
trip_desc_hotel = get_desc(hotel, hotels_info)
|
257 |
+
review_hotel = reviews_summary.get(hotel)
|
258 |
+
payload = get_payload(hotel, df)
|
259 |
+
text = trip_desc_hotel if trip_desc_hotel else '' + '\n' + review_hotel if review_hotel else ''
|
260 |
+
payload['description'] = text
|
261 |
+
payloads.append(payload)
|
262 |
+
texts.append(text)
|
263 |
+
|
264 |
+
# Create description embeddings and upsert them to the database
|
265 |
+
openai_client = OpenAI()
|
266 |
+
embeddings = openai_client.embeddings.create(input=texts, model=embeddings_model)
|
267 |
+
points = [
|
268 |
+
models.PointStruct(
|
269 |
+
id=idx,
|
270 |
+
vector=data.embedding,
|
271 |
+
payload=payload,
|
272 |
+
)
|
273 |
+
for idx, (data, payload) in enumerate(zip(embeddings.data, payloads))
|
274 |
+
]
|
275 |
+
qdrant_client.upsert(collection_name, points)
|
276 |
+
|
277 |
+
|
278 |
+
if __name__ == '__main__':
|
279 |
+
create_vector_db()
|
src/prompts.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
RAG_SYSTEM_PROMPT = "You are a helpful assistant, who recommends the hotels based only on my preferences."
|
2 |
+
|
3 |
+
RAG_CONTEXT_TEMPLATE = """
|
4 |
+
{id}: {hotel_name}
|
5 |
+
{description}
|
6 |
+
"""
|
7 |
+
|
8 |
+
RAG_USER_PROMPT = """
|
9 |
+
Here are the information about most relevant hotels to my query
|
10 |
+
---------------------
|
11 |
+
{context}
|
12 |
+
---------------------
|
13 |
+
Present these results to me and justify the ranking (explain why a hotel matches my preferences). Don't draw ANY conclusion and don't based on own knowledge.
|
14 |
+
Query: {query}
|
15 |
+
Answer:
|
16 |
+
"""
|
17 |
+
|
18 |
+
AGENT_USER_PROMPT = """
|
19 |
+
Answer the following question as best you can. You have access to the following tools:
|
20 |
+
|
21 |
+
hotels_data_base: A tool which present information about most relevant hotels based on the query. The information contains pros and cons of the hotel based on reviews, reviews ratings and ammenities. It is usefull when user want to get hotels recommendations. In this case Action Input should be query which will be complete and usefull to retrive the most relevant hotels.
|
22 |
+
ares_api: An API which performs real-time internet searches. It can be usefull than you need specific information about the hotel or the locataion or smth else from the internet. In this case Action Input should be query which will be complete and usefull to retrive the information from the Internet.
|
23 |
+
nothing: If you are sure you can answer the user's query without additional tools. In this case Action Input should be just an answer.
|
24 |
+
|
25 |
+
Use the following format:
|
26 |
+
|
27 |
+
Question: the input question you must answer
|
28 |
+
Thought: you should always think about what to do
|
29 |
+
Action: the action to take, should be one of [hotels_data_base, ares_api, nothing]
|
30 |
+
Action Input: the input to the action
|
31 |
+
|
32 |
+
Begin!
|
33 |
+
|
34 |
+
Question: {input}
|
35 |
+
Thought:
|
36 |
+
"""
|
37 |
+
|
38 |
+
AGENT_SYSTEM_PROMPT = "You are a helpful assistant for a hotel recommendation system based on my preferences. Answer all questions to the best of your ability."
|
39 |
+
|
40 |
+
REVIEWS_SYSTEM_PROMPT = "You are a helpful assistant. Your goal is to underpin the strong and the weak points (features, amenities). If you can't find strong or weak points, don't write ANYTHING about them. The information consists of hotel reviews, i.e. Title of the review and the Review itself."
|
41 |
+
REVIEWS_USER_PROMPT = """{text} Good Example:
|
42 |
+
### Strong Points:
|
43 |
+
- The hotel boasts a favorable location with sea views and proximity to Zeitinburnu train station.
|
44 |
+
- Upgraded rooms, fitness facilities, and the outdoor pool area are well-received.
|
45 |
+
- The staff, including specific individuals like Mr. Levent, Cihan, and Buse, have been commended for their service.
|
46 |
+
- Room cleanliness is frequently mentioned as a positive aspect.
|
47 |
+
|
48 |
+
### Weak Points:
|
49 |
+
- Inconsistency in customer service, with some guests reporting a lack of assistance with luggage and unfriendly reception.
|
50 |
+
- Miscommunication regarding room rates and issues with overcharges.
|
51 |
+
- Some guests have found the hotel's amenities, such as the narrow balcony and the pool's restrictive rules, to be lacking.
|
52 |
+
- A few guests reported cleanliness issues in the bathroom and concerns with room repairs.
|
53 |
+
"""
|
54 |
+
|
55 |
+
TRAVERSIALAI_USER_PROMPT = """
|
56 |
+
Based on the information retrived from the internet, answer the following question as best you can.
|
57 |
+
---------------------
|
58 |
+
{context}
|
59 |
+
---------------------
|
60 |
+
Query: {query}
|
61 |
+
Answer:
|
62 |
+
"""
|
src/retriever.py
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from openai import OpenAI
|
2 |
+
import cohere
|
3 |
+
from qdrant_client import models
|
4 |
+
from src.prompts import RAG_CONTEXT_TEMPLATE
|
5 |
+
|
6 |
+
|
7 |
+
class Retriever:
|
8 |
+
"""Retriever class for retrieving documents from the database
|
9 |
+
For retrieving documents, the following steps are performed:
|
10 |
+
1. Create an embedding for the query
|
11 |
+
2. Get n documents from the database based on the query and filters (Mixed retrieval)
|
12 |
+
3. Rerank the documents based on the query and select top k documents, where k << n (ReRanking)
|
13 |
+
4. Create a context from the selected documents
|
14 |
+
"""
|
15 |
+
def __init__(self, embedding_model, llm_model, rerank_model, db_client, db_collection='hotels'):
|
16 |
+
self.db_collection = db_collection
|
17 |
+
self.db_client = db_client
|
18 |
+
self.rerank_model = rerank_model
|
19 |
+
self.openai_client = OpenAI()
|
20 |
+
self.co = cohere.Client()
|
21 |
+
self.embedding_model = embedding_model
|
22 |
+
self.llm_model = llm_model
|
23 |
+
self.max_retrieved_docs = 13
|
24 |
+
|
25 |
+
def _get_documents(self, query, top_k, city, price, rating):
|
26 |
+
"""Retrieve top n documents from the database based on the query and filters
|
27 |
+
|
28 |
+
Args:
|
29 |
+
query (str): query
|
30 |
+
top_k (int): number of documents to retrieve
|
31 |
+
city (str): city name
|
32 |
+
price (str): price range
|
33 |
+
rating (float): rating
|
34 |
+
|
35 |
+
Returns:
|
36 |
+
list: list of documents
|
37 |
+
"""
|
38 |
+
embedding = self.openai_client.embeddings.create(input=query, model=self.embedding_model)
|
39 |
+
filtr = []
|
40 |
+
if city:
|
41 |
+
filtr.append(models.FieldCondition(key="city", match=models.MatchValue(value=city)))
|
42 |
+
if price:
|
43 |
+
filtr.append(models.FieldCondition(key="price", match=models.MatchValue(value=price)))
|
44 |
+
if rating:
|
45 |
+
filtr.append(models.FieldCondition(key="rating", range=models.Range(gte=rating)))
|
46 |
+
response = self.db_client.search(
|
47 |
+
collection_name=self.db_collection,
|
48 |
+
query_vector=embedding.data[0].embedding,
|
49 |
+
limit=top_k,
|
50 |
+
query_filter=models.Filter(
|
51 |
+
must=filtr
|
52 |
+
),
|
53 |
+
)
|
54 |
+
return response
|
55 |
+
|
56 |
+
def _get_context(self, docs):
|
57 |
+
"""Create a context from the retrieved documents
|
58 |
+
|
59 |
+
Args:
|
60 |
+
docs (list): list of documents
|
61 |
+
|
62 |
+
Returns:
|
63 |
+
str: context
|
64 |
+
"""
|
65 |
+
context = ''
|
66 |
+
for i, doc in enumerate(docs, 1):
|
67 |
+
context += RAG_CONTEXT_TEMPLATE.format(id=i, hotel_name=doc.payload['hotel_name'], description=doc.payload['description'])
|
68 |
+
return context
|
69 |
+
|
70 |
+
def _reranker(self, docs, query, top_k):
|
71 |
+
"""Rerank the retrieved documents using Cohere based on the query and select top k documents
|
72 |
+
|
73 |
+
Args:
|
74 |
+
docs (list): list of documents
|
75 |
+
query (str): query
|
76 |
+
top_k (int): number of documents to select
|
77 |
+
|
78 |
+
Returns:
|
79 |
+
list: list of reranked documents
|
80 |
+
"""
|
81 |
+
texts = [doc.payload['description'] for doc in docs]
|
82 |
+
rerank_hits = self.co.rerank(query=query, documents=texts, top_n=top_k, model=self.rerank_model)
|
83 |
+
result = [docs[hit.index] for hit in rerank_hits[:top_k]]
|
84 |
+
return result
|
85 |
+
|
86 |
+
def __call__(self, query, top_k=3, city=None, price=None, rating=None):
|
87 |
+
docs = self._get_documents(query, top_k=max(self.max_retrieved_docs, top_k), city=city, price=price, rating=rating)
|
88 |
+
if len(docs) == 0:
|
89 |
+
return 'There are no such hotels'
|
90 |
+
docs = self._reranker(docs, query, top_k)
|
91 |
+
context = self._get_context(docs)
|
92 |
+
return context
|
src/streamlit_utils.py
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import streamlit as st
|
3 |
+
from langchain.callbacks.base import BaseCallbackHandler
|
4 |
+
|
5 |
+
|
6 |
+
class StreamHandler(BaseCallbackHandler):
|
7 |
+
def __init__(self, container, initial_text=""):
|
8 |
+
self.container = container
|
9 |
+
self.text = initial_text
|
10 |
+
|
11 |
+
def on_llm_new_token(self, token: str, **kwargs) -> None:
|
12 |
+
self.text += token
|
13 |
+
self.container.markdown(self.text)
|
14 |
+
|
15 |
+
|
16 |
+
def enable_chat_history(func):
|
17 |
+
if os.environ.get("OPENAI_API_KEY"):
|
18 |
+
|
19 |
+
# to clear chat history after swtching chatbot
|
20 |
+
current_page = func.__qualname__
|
21 |
+
if "current_page" not in st.session_state:
|
22 |
+
st.session_state["current_page"] = current_page
|
23 |
+
if st.session_state["current_page"] != current_page:
|
24 |
+
try:
|
25 |
+
st.cache_resource.clear()
|
26 |
+
del st.session_state["current_page"]
|
27 |
+
del st.session_state["messages"]
|
28 |
+
except:
|
29 |
+
pass
|
30 |
+
|
31 |
+
# to show chat history on ui
|
32 |
+
if "messages" not in st.session_state:
|
33 |
+
st.session_state["messages"] = [{"role": "assistant", "content": "How can I help you?"}]
|
34 |
+
for msg in st.session_state["messages"]:
|
35 |
+
st.chat_message(msg["role"]).write(msg["content"])
|
36 |
+
|
37 |
+
def execute(*args, **kwargs):
|
38 |
+
func(*args, **kwargs)
|
39 |
+
return execute
|
40 |
+
|
41 |
+
|
42 |
+
def display_msg(msg, author):
|
43 |
+
"""Method to display message on the UI
|
44 |
+
|
45 |
+
Args:
|
46 |
+
msg (str): message to display
|
47 |
+
author (str): author of the message -user/assistant
|
48 |
+
"""
|
49 |
+
st.session_state.messages.append({"role": author, "content": msg})
|
50 |
+
st.chat_message(author).write(msg)
|
51 |
+
|
52 |
+
|
53 |
+
def configure_api_keys():
|
54 |
+
KEYS = ['OPENAI_API_KEY', 'CO_API_KEY', 'ARES_API_KEY']
|
55 |
+
st.sidebar.header('Api Keys Configuration')
|
56 |
+
st.markdown(
|
57 |
+
"""
|
58 |
+
<style>
|
59 |
+
[title="Show password text"] {
|
60 |
+
display: none;
|
61 |
+
}
|
62 |
+
</style>
|
63 |
+
""",
|
64 |
+
unsafe_allow_html=True,
|
65 |
+
)
|
66 |
+
for key in KEYS:
|
67 |
+
if key in os.environ:
|
68 |
+
st.session_state[key] = os.environ[key]
|
69 |
+
api_key = st.sidebar.text_input(
|
70 |
+
label=key,
|
71 |
+
type="password",
|
72 |
+
value=st.session_state[key] if key in st.session_state else '',
|
73 |
+
placeholder="..."
|
74 |
+
)
|
75 |
+
if api_key:
|
76 |
+
st.session_state[key] = api_key
|
77 |
+
os.environ[key] = api_key
|
78 |
+
else:
|
79 |
+
st.error(f"Please add your {key} to continue.")
|
80 |
+
st.stop()
|