Update custom_utils.py

custom_utils.py (+95 -108) CHANGED
@@ -12,6 +12,14 @@ from pymongo.mongo_client import MongoClient
 DB_NAME = "airbnb_dataset"
 COLLECTION_NAME = "listings_reviews"
 
+def connect_to_database():
+    MONGODB_ATLAS_CLUSTER_URI = os.environ["MONGODB_ATLAS_CLUSTER_URI"]
+
+    mongo_client = MongoClient(MONGODB_ATLAS_CLUSTER_URI, appname="advanced-rag")
+    db = mongo_client.get_database(DB_NAME)
+    collection = db.get_collection(COLLECTION_NAME)
+    return db, collection
+
 def rag_ingestion(collection):
     dataset = load_dataset("MongoDB/airbnb_embeddings", streaming=True, split="train")
     dataset_df = pd.DataFrame(dataset)
@@ -20,6 +28,54 @@ def rag_ingestion(collection):
     collection.insert_many(listings)
     # Manually create a vector search index (in free tier, this feature is not available via SDK)
 
+def rag_retrieval(openai_api_key, prompt, db, collection, stages=[], vector_index="vector_index"):
+    # Assuming vector_search returns a list of dictionaries with keys 'title' and 'plot'
+    get_knowledge = vector_search(openai_api_key, prompt, db, collection, stages, vector_index)
+
+    # Check if there are any results
+    if not get_knowledge:
+        return "No results found.", "No source information available."
+
+    # Convert search results into a list of SearchResultItem models
+    search_results_models = [
+        SearchResultItem(**result)
+        for result in get_knowledge
+    ]
+
+    # Convert search results into a DataFrame for better rendering in Jupyter
+    search_results_df = pd.DataFrame([item.dict() for item in search_results_models])
+
+    return search_results_df
+
+def rag_inference(openai_api_key, prompt, search_results_df):
+    openai.api_key = openai_api_key
+
+    # Generate system response using OpenAI's completion
+    content = f"Answer this user question: {prompt} with the following context:\n{search_results_df}"
+
+    completion = openai.chat.completions.create(
+        model="gpt-4o",
+        messages=[
+            {
+                "role": "system",
+                "content": "You are an AirBnB listing recommendation system."},
+            {
+                "role": "user",
+                "content": content
+            }
+        ]
+    )
+
+    result = completion.choices[0].message.content
+
+    print("###")
+    print(f"- User Content:\n{content}\n")
+    print("###")
+    print(f"- Prompt Completion:\n{result}\n")
+    print("###")
+
+    return result
+
 def process_records(data_frame):
     records = data_frame.to_dict(orient="records")
 
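The comment in the hunk above notes that the vector search index is created by hand in the Atlas UI rather than through the driver. For reference, a definition along these lines would match what the search functions in this file expect — text_embeddings as the indexed path and 1536 dimensions to match the text-embedding-3-small embeddings produced by get_embedding; the cosine similarity choice is an illustrative assumption, not something recorded in this commit:

{
  "fields": [
    {
      "type": "vector",
      "path": "text_embeddings",
      "numDimensions": 1536,
      "similarity": "cosine"
    }
  ]
}

Saved under the name vector_index_text, this would be the index that vector_search defaults to below.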
@@ -36,31 +92,13 @@ def process_records(data_frame):
             record[key] = None
 
     try:
-        print("111")
         # Convert each dictionary to a Listing instance
         return [Listing(**record).dict() for record in records]
     except ValidationError as e:
         print("Validation error:", e)
         return []
 
-def get_embedding(text):
-    """Generate an embedding for the given text using OpenAI's API."""
-
-    # Check for valid input
-    if not text or not isinstance(text, str):
-        return None
-
-    try:
-        # Call OpenAI API to get the embedding
-        embedding = openai.embeddings.create(
-            input=text,
-            model="text-embedding-3-small", dimensions=1536).data[0].embedding
-        return embedding
-    except Exception as e:
-        print(f"Error in get_embedding: {e}")
-        return None
-
-def vector_search_with_filter(user_query, db, collection, additional_stages=[], vector_index="vector_index_text"):
+def vector_search(openai_api_key, user_query, db, collection, additional_stages=[], vector_index="vector_index_text"):
     """
     Perform a vector search in the MongoDB collection based on the user query.
 
@@ -75,7 +113,7 @@ def vector_search_with_filter(user_query, db, collection, additional_stages=[],
     """
 
     # Generate embedding for the user query
-    query_embedding = get_embedding(user_query)
+    query_embedding = get_embedding(openai_api_key, user_query)
 
     if query_embedding is None:
         return "Invalid query or embedding generation failed."
@@ -83,21 +121,14 @@ def vector_search_with_filter(user_query, db, collection, additional_stages=[],
     # Define the vector search stage
    vector_search_stage = {
         "$vectorSearch": {
-            "index": vector_index,
-            "queryVector": query_embedding,
-            "path": "text_embeddings",
-            "numCandidates": 150,
-            "limit": 20,
-            "filter": {
-                "$and": [
-                    {"accommodates": {"$gte": 2}},
-                    {"bedrooms": {"$lte": 7}}
-                ]
-            },
+            "index": vector_index, # specifies the index to use for the search
+            "queryVector": query_embedding, # the vector representing the query
+            "path": "text_embeddings", # field in the documents containing the vectors to search against
+            "numCandidates": 150, # number of candidate matches to consider
+            "limit": 20, # return top 20 matches
         }
     }
 
-
     # Define the aggregate pipeline with the vector search stage and additional stages
     pipeline = [vector_search_stage] + additional_stages
 
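Since the pipeline is built as [vector_search_stage] + additional_stages, callers can append ordinary aggregation stages after the $vectorSearch stage. A minimal sketch of such a stage (the projected field names are hypothetical, not taken from this commit): it trims each match to a few fields and surfaces the similarity score Atlas attaches to every $vectorSearch result.

# Hypothetical extra stage passed via additional_stages: keep a few fields
# and expose the $vectorSearch similarity score for each match.
projection_stage = {
    "$project": {
        "_id": 0,
        "name": 1,          # assumed listing fields, for illustration only
        "accommodates": 1,
        "bedrooms": 1,
        "score": {"$meta": "vectorSearchScore"},
    }
}

results = vector_search(openai_api_key, "a quiet loft for two", db, collection,
                        additional_stages=[projection_stage])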
@@ -113,34 +144,14 @@ def vector_search_with_filter(user_query, db, collection, additional_stages=[],
         verbosity='executionStats') # detailed statistics about the execution of each stage of the aggregation pipeline
 
     vector_search_explain = explain_query_execution['stages'][0]['$vectorSearch']
-    millis_elapsed = vector_search_explain['explain']['collectStats']['millisElapsed']
 
-
+    #millis_elapsed = vector_search_explain['explain']['collectStats']['millisElapsed']
+    print(vector_search_explain)
+    #print(f"Total time for the execution to complete on the database server: {millis_elapsed} milliseconds")
 
     return list(results)
 
-
-
-
-def connect_to_database():
-    """Establish connection to the MongoDB."""
-
-    MONGO_URI = os.environ["MONGODB_ATLAS_CLUSTER_URI"]
-
-    if not MONGO_URI:
-        print("MONGO_URI not set in environment variables")
-
-    # gateway to interacting with a MongoDB database cluster
-    mongo_client = MongoClient(MONGO_URI, appname="advanced-rag")
-    print("Connection to MongoDB successful")
-
-    # Pymongo client of database and collection
-    db = mongo_client.get_database(DB_NAME)
-    collection = db.get_collection(COLLECTION_NAME)
-
-    return db, collection
-
-def vector_search(user_query, db, collection, additional_stages=[], vector_index="vector_index_text"):
+def vector_search_with_filter(user_query, db, collection, additional_stages=[], vector_index="vector_index_2"):
     """
     Perform a vector search in the MongoDB collection based on the user query.
 
@@ -163,14 +174,21 @@ def vector_search(user_query, db, collection, additional_stages=[], vector_index
     # Define the vector search stage
     vector_search_stage = {
         "$vectorSearch": {
-            "index": vector_index,
-            "queryVector": query_embedding,
-            "path": "text_embeddings",
-            "numCandidates": 150,
-            "limit": 20,
+            "index": vector_index, # specifies the index to use for the search
+            "queryVector": query_embedding, # the vector representing the query
+            "path": "text_embeddings", # field in the documents containing the vectors to search against
+            "numCandidates": 150, # number of candidate matches to consider
+            "limit": 20, # return top 20 matches
+            "filter": {
+                "$and": [
+                    {"accommodates": {"$gte": 2}},
+                    {"bedrooms": {"$lte": 7}}
+                ]
+            },
         }
     }
 
+
     # Define the aggregate pipeline with the vector search stage and additional stages
     pipeline = [vector_search_stage] + additional_stages
 
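One caveat with the filter added above: Atlas Vector Search only accepts $vectorSearch filter predicates on fields declared as filter fields in the index definition, which is presumably why this function defaults to a separate index name (vector_index_2). A sketch of an index definition that would support the accommodates/bedrooms pre-filter — again an assumption, since the commit does not include the index itself:

{
  "fields": [
    {
      "type": "vector",
      "path": "text_embeddings",
      "numDimensions": 1536,
      "similarity": "cosine"
    },
    {"type": "filter", "path": "accommodates"},
    {"type": "filter", "path": "bedrooms"}
  ]
}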
@@ -186,57 +204,26 @@ def vector_search(user_query, db, collection, additional_stages=[], vector_index
         verbosity='executionStats') # detailed statistics about the execution of each stage of the aggregation pipeline
 
     vector_search_explain = explain_query_execution['stages'][0]['$vectorSearch']
+    millis_elapsed = vector_search_explain['explain']['collectStats']['millisElapsed']
 
-
-    print(vector_search_explain)
-    #print(f"Total time for the execution to complete on the database server: {millis_elapsed} milliseconds")
+    print(f"Total time for the execution to complete on the database server: {millis_elapsed} milliseconds")
 
     return list(results)
 
-def rag_retrieval(prompt, db, collection, stages=[], vector_index="vector_index"):
-
-    get_knowledge = vector_search(prompt, db, collection, stages, vector_index)
-
-    # Check if there are any results
-    if not get_knowledge:
-        return "No results found.", "No source information available."
-
-    # Convert search results into a list of SearchResultItem models
-    search_results_models = [
-        SearchResultItem(**result)
-        for result in get_knowledge
-    ]
-
-    # Convert search results into a DataFrame for better rendering in Jupyter
-    search_results_df = pd.DataFrame([item.dict() for item in search_results_models])
+def get_embedding(openai_api_key, text):
+    """Generate an embedding for the given text using OpenAI's API."""
 
-
+    # Check for valid input
+    if not text or not isinstance(text, str):
+        return None
 
-def rag_inference(openai_api_key, prompt, search_results_df):
     openai.api_key = openai_api_key
-
-    # Generate system response using OpenAI's completion
-    content = f"Answer this user question: {prompt} with the following context:\n{search_results_df}"
-
-    completion = openai.chat.completions.create(
-        model="gpt-4o",
-        messages=[
-            {
-                "role": "system",
-                "content": "You are an AirBnB listing recommendation system."},
-            {
-                "role": "user",
-                "content": content
-            }
-        ]
-    )
-
-    result = completion.choices[0].message.content
-
-    print("###")
-    print(f"- User Content:\n{content}\n")
-    print("###")
-    print(f"- Prompt Completion:\n{result}\n")
-    print("###")
-
-    return result
+
+    try:
+        embedding = openai.embeddings.create(
+            input=text,
+            model="text-embedding-3-small", dimensions=1536).data[0].embedding
+        return embedding
+    except Exception as e:
+        print(f"Error in get_embedding: {e}")
+        return None
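After this change the module reads top-down: connect, ingest, retrieve, generate. A minimal usage sketch, assuming the file is importable as custom_utils, that MONGODB_ATLAS_CLUSTER_URI and OPENAI_API_KEY are set, and that the Atlas indexes already exist (the prompt string is illustrative):

import os
import custom_utils

openai_api_key = os.environ["OPENAI_API_KEY"]

# Connect once; the handles are reused by the retrieval functions.
db, collection = custom_utils.connect_to_database()

# One-time ingestion of the pre-embedded Airbnb dataset.
# custom_utils.rag_ingestion(collection)

prompt = "Recommend a quiet place for two adults near the beach."

# Retrieve matching listings via $vectorSearch, then answer with gpt-4o.
search_results_df = custom_utils.rag_retrieval(
    openai_api_key, prompt, db, collection, vector_index="vector_index_text")
answer = custom_utils.rag_inference(openai_api_key, prompt, search_results_df)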