bstraehle committed on
Commit
738445b
·
verified ·
1 Parent(s): 94639d3

Update custom_utils.py

Browse files
Files changed (1) hide show
  1. custom_utils.py +95 -108
custom_utils.py CHANGED
@@ -12,6 +12,14 @@ from pymongo.mongo_client import MongoClient
12
  DB_NAME = "airbnb_dataset"
13
  COLLECTION_NAME = "listings_reviews"
14
 
 
 
 
 
 
 
 
 
15
  def rag_ingestion(collection):
16
  dataset = load_dataset("MongoDB/airbnb_embeddings", streaming=True, split="train")
17
  dataset_df = pd.DataFrame(dataset)
@@ -20,6 +28,54 @@ def rag_ingestion(collection):
20
  collection.insert_many(listings)
21
  # Manually create a vector search index (in free tier, this feature is not available via SDK)
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  def process_records(data_frame):
24
  records = data_frame.to_dict(orient="records")
25
 
@@ -36,31 +92,13 @@ def process_records(data_frame):
36
  record[key] = None
37
 
38
  try:
39
- print("111")
40
  # Convert each dictionary to a Listing instance
41
  return [Listing(**record).dict() for record in records]
42
  except ValidationError as e:
43
  print("Validation error:", e)
44
  return []
45
 
46
- def get_embedding(text):
47
- """Generate an embedding for the given text using OpenAI's API."""
48
-
49
- # Check for valid input
50
- if not text or not isinstance(text, str):
51
- return None
52
-
53
- try:
54
- # Call OpenAI API to get the embedding
55
- embedding = openai.embeddings.create(
56
- input=text,
57
- model="text-embedding-3-small", dimensions=1536).data[0].embedding
58
- return embedding
59
- except Exception as e:
60
- print(f"Error in get_embedding: {e}")
61
- return None
62
-
63
- def vector_search_with_filter(user_query, db, collection, additional_stages=[], vector_index="vector_index_text"):
64
  """
65
  Perform a vector search in the MongoDB collection based on the user query.
66
 
@@ -75,7 +113,7 @@ def vector_search_with_filter(user_query, db, collection, additional_stages=[],
75
  """
76
 
77
  # Generate embedding for the user query
78
- query_embedding = get_embedding(user_query)
79
 
80
  if query_embedding is None:
81
  return "Invalid query or embedding generation failed."
@@ -83,21 +121,14 @@ def vector_search_with_filter(user_query, db, collection, additional_stages=[],
83
  # Define the vector search stage
84
  vector_search_stage = {
85
  "$vectorSearch": {
86
- "index": vector_index, # specifies the index to use for the search
87
- "queryVector": query_embedding, # the vector representing the query
88
- "path": "text_embeddings", # field in the documents containing the vectors to search against
89
- "numCandidates": 150, # number of candidate matches to consider
90
- "limit": 20, # return top 20 matches
91
- "filter": {
92
- "$and": [
93
- {"accommodates": {"$gte": 2}},
94
- {"bedrooms": {"$lte": 7}}
95
- ]
96
- },
97
  }
98
  }
99
 
100
-
101
  # Define the aggregate pipeline with the vector search stage and additional stages
102
  pipeline = [vector_search_stage] + additional_stages
103
 
@@ -113,34 +144,14 @@ def vector_search_with_filter(user_query, db, collection, additional_stages=[],
113
  verbosity='executionStats') # detailed statistics about the execution of each stage of the aggregation pipeline
114
 
115
  vector_search_explain = explain_query_execution['stages'][0]['$vectorSearch']
116
- millis_elapsed = vector_search_explain['explain']['collectStats']['millisElapsed']
117
 
118
- print(f"Total time for the execution to complete on the database server: {millis_elapsed} milliseconds")
 
 
119
 
120
  return list(results)
121
 
122
-
123
-
124
-
125
- def connect_to_database():
126
- """Establish connection to the MongoDB."""
127
-
128
- MONGO_URI = os.environ["MONGODB_ATLAS_CLUSTER_URI"]
129
-
130
- if not MONGO_URI:
131
- print("MONGO_URI not set in environment variables")
132
-
133
- # gateway to interacting with a MongoDB database cluster
134
- mongo_client = MongoClient(MONGO_URI, appname="advanced-rag")
135
- print("Connection to MongoDB successful")
136
-
137
- # Pymongo client of database and collection
138
- db = mongo_client.get_database(DB_NAME)
139
- collection = db.get_collection(COLLECTION_NAME)
140
-
141
- return db, collection
142
-
143
- def vector_search(user_query, db, collection, additional_stages=[], vector_index="vector_index_text"):
144
  """
145
  Perform a vector search in the MongoDB collection based on the user query.
146
 
@@ -163,14 +174,21 @@ def vector_search(user_query, db, collection, additional_stages=[], vector_index
163
  # Define the vector search stage
164
  vector_search_stage = {
165
  "$vectorSearch": {
166
- "index": vector_index, # specifies the index to use for the search
167
- "queryVector": query_embedding, # the vector representing the query
168
- "path": "text_embeddings", # field in the documents containing the vectors to search against
169
- "numCandidates": 150, # number of candidate matches to consider
170
- "limit": 20, # return top 20 matches
 
 
 
 
 
 
171
  }
172
  }
173
 
 
174
  # Define the aggregate pipeline with the vector search stage and additional stages
175
  pipeline = [vector_search_stage] + additional_stages
176
 
@@ -186,57 +204,26 @@ def vector_search(user_query, db, collection, additional_stages=[], vector_index
186
  verbosity='executionStats') # detailed statistics about the execution of each stage of the aggregation pipeline
187
 
188
  vector_search_explain = explain_query_execution['stages'][0]['$vectorSearch']
 
189
 
190
- #millis_elapsed = vector_search_explain['explain']['collectStats']['millisElapsed']
191
- print(vector_search_explain)
192
- #print(f"Total time for the execution to complete on the database server: {millis_elapsed} milliseconds")
193
 
194
  return list(results)
195
 
196
- def rag_retrieval(prompt, db, collection, stages=[], vector_index="vector_index"):
197
- # Assuming vector_search returns a list of dictionaries with keys 'title' and 'plot'
198
- get_knowledge = vector_search(prompt, db, collection, stages, vector_index)
199
-
200
- # Check if there are any results
201
- if not get_knowledge:
202
- return "No results found.", "No source information available."
203
-
204
- # Convert search results into a list of SearchResultItem models
205
- search_results_models = [
206
- SearchResultItem(**result)
207
- for result in get_knowledge
208
- ]
209
-
210
- # Convert search results into a DataFrame for better rendering in Jupyter
211
- search_results_df = pd.DataFrame([item.dict() for item in search_results_models])
212
 
213
- return search_results_df
 
 
214
 
215
- def rag_inference(openai_api_key, prompt, search_results_df):
216
  openai.api_key = openai_api_key
217
-
218
- # Generate system response using OpenAI's completion
219
- content = f"Answer this user question: {prompt} with the following context:\n{search_results_df}"
220
-
221
- completion = openai.chat.completions.create(
222
- model="gpt-4o",
223
- messages=[
224
- {
225
- "role": "system",
226
- "content": "You are an AirBnB listing recommendation system."},
227
- {
228
- "role": "user",
229
- "content": content
230
- }
231
- ]
232
- )
233
-
234
- result = completion.choices[0].message.content
235
-
236
- print("###")
237
- print(f"- User Content:\n{content}\n")
238
- print("###")
239
- print(f"- Prompt Completion:\n{result}\n")
240
- print("###")
241
-
242
- return result
 
12
  DB_NAME = "airbnb_dataset"
13
  COLLECTION_NAME = "listings_reviews"
14
 
15
def connect_to_database():
    """Establish a connection to the MongoDB Atlas cluster.

    Reads the cluster URI from the ``MONGODB_ATLAS_CLUSTER_URI`` environment
    variable and returns handles to the project's database and collection.

    Returns:
        tuple: ``(db, collection)`` for ``DB_NAME`` / ``COLLECTION_NAME``.

    Raises:
        KeyError: if ``MONGODB_ATLAS_CLUSTER_URI`` is not set at all.
        ValueError: if ``MONGODB_ATLAS_CLUSTER_URI`` is set but empty.
    """
    mongodb_atlas_cluster_uri = os.environ["MONGODB_ATLAS_CLUSTER_URI"]

    # Fail fast with a clear message instead of letting MongoClient choke
    # on an empty connection string later.
    if not mongodb_atlas_cluster_uri:
        raise ValueError("MONGODB_ATLAS_CLUSTER_URI is set but empty")

    # appname tags the connection for server-side diagnostics.
    mongo_client = MongoClient(mongodb_atlas_cluster_uri, appname="advanced-rag")

    db = mongo_client.get_database(DB_NAME)
    collection = db.get_collection(COLLECTION_NAME)
    return db, collection
22
+
23
  def rag_ingestion(collection):
24
  dataset = load_dataset("MongoDB/airbnb_embeddings", streaming=True, split="train")
25
  dataset_df = pd.DataFrame(dataset)
 
28
  collection.insert_many(listings)
29
  # Manually create a vector search index (in free tier, this feature is not available via SDK)
30
 
31
def rag_retrieval(openai_api_key, prompt, db, collection, stages=None, vector_index="vector_index"):
    """Retrieve listings relevant to *prompt* via Atlas vector search.

    Args:
        openai_api_key: OpenAI key used to embed the prompt.
        prompt: Natural-language user query.
        db: pymongo database handle (passed through to ``vector_search``).
        collection: pymongo collection to search.
        stages: Optional extra aggregation stages appended to the pipeline.
        vector_index: Name of the Atlas vector search index to use.

    Returns:
        pandas.DataFrame of validated ``SearchResultItem`` rows, or a
        2-tuple of fallback strings when the search yields nothing.
    """
    # Avoid the shared-mutable-default pitfall; [] was the old default value.
    if stages is None:
        stages = []

    # vector_search returns a list of dicts shaped like SearchResultItem.
    get_knowledge = vector_search(openai_api_key, prompt, db, collection, stages, vector_index)

    # Check if there are any results
    if not get_knowledge:
        return "No results found.", "No source information available."

    # Validate each raw result through the SearchResultItem model.
    search_results_models = [
        SearchResultItem(**result)
        for result in get_knowledge
    ]

    # Convert search results into a DataFrame for better rendering in Jupyter
    search_results_df = pd.DataFrame([item.dict() for item in search_results_models])

    return search_results_df
49
+
50
def rag_inference(openai_api_key, prompt, search_results_df):
    """Answer the user's question with GPT-4o, grounded in the search results.

    Builds a single user message that embeds *prompt* plus the retrieved
    listings DataFrame as context, sends it to the chat-completions API,
    prints the exchange for debugging, and returns the model's reply text.
    """
    openai.api_key = openai_api_key

    # Fold the question and the retrieved context into one prompt string.
    content = f"Answer this user question: {prompt} with the following context:\n{search_results_df}"

    system_message = {
        "role": "system",
        "content": "You are an AirBnB listing recommendation system."}
    user_message = {
        "role": "user",
        "content": content}

    completion = openai.chat.completions.create(
        model="gpt-4o",
        messages=[system_message, user_message],
    )

    answer = completion.choices[0].message.content

    # Debug trace of what was sent and what came back.
    for label, text in (("User Content", content), ("Prompt Completion", answer)):
        print("###")
        print(f"- {label}:\n{text}\n")
    print("###")

    return answer
78
+
79
  def process_records(data_frame):
80
  records = data_frame.to_dict(orient="records")
81
 
 
92
  record[key] = None
93
 
94
  try:
 
95
  # Convert each dictionary to a Listing instance
96
  return [Listing(**record).dict() for record in records]
97
  except ValidationError as e:
98
  print("Validation error:", e)
99
  return []
100
 
101
+ def vector_search(openai_api_key, user_query, db, collection, additional_stages=[], vector_index="vector_index_text"):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
  """
103
  Perform a vector search in the MongoDB collection based on the user query.
104
 
 
113
  """
114
 
115
  # Generate embedding for the user query
116
+ query_embedding = get_embedding(openai_api_key, user_query)
117
 
118
  if query_embedding is None:
119
  return "Invalid query or embedding generation failed."
 
121
  # Define the vector search stage
122
  vector_search_stage = {
123
  "$vectorSearch": {
124
+ "index": vector_index, # specifies the index to use for the search
125
+ "queryVector": query_embedding, # the vector representing the query
126
+ "path": "text_embeddings", # field in the documents containing the vectors to search against
127
+ "numCandidates": 150, # number of candidate matches to consider
128
+ "limit": 20, # return top 20 matches
 
 
 
 
 
 
129
  }
130
  }
131
 
 
132
  # Define the aggregate pipeline with the vector search stage and additional stages
133
  pipeline = [vector_search_stage] + additional_stages
134
 
 
144
  verbosity='executionStats') # detailed statistics about the execution of each stage of the aggregation pipeline
145
 
146
  vector_search_explain = explain_query_execution['stages'][0]['$vectorSearch']
 
147
 
148
+ #millis_elapsed = vector_search_explain['explain']['collectStats']['millisElapsed']
149
+ print(vector_search_explain)
150
+ #print(f"Total time for the execution to complete on the database server: {millis_elapsed} milliseconds")
151
 
152
  return list(results)
153
 
154
+ def vector_search_with_filter(user_query, db, collection, additional_stages=[], vector_index="vector_index_2"):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
  """
156
  Perform a vector search in the MongoDB collection based on the user query.
157
 
 
174
  # Define the vector search stage
175
  vector_search_stage = {
176
  "$vectorSearch": {
177
+ "index": vector_index, # specifies the index to use for the search
178
+ "queryVector": query_embedding, # the vector representing the query
179
+ "path": "text_embeddings", # field in the documents containing the vectors to search against
180
+ "numCandidates": 150, # number of candidate matches to consider
181
+ "limit": 20, # return top 20 matches
182
+ "filter": {
183
+ "$and": [
184
+ {"accommodates": {"$gte": 2}},
185
+ {"bedrooms": {"$lte": 7}}
186
+ ]
187
+ },
188
  }
189
  }
190
 
191
+
192
  # Define the aggregate pipeline with the vector search stage and additional stages
193
  pipeline = [vector_search_stage] + additional_stages
194
 
 
204
  verbosity='executionStats') # detailed statistics about the execution of each stage of the aggregation pipeline
205
 
206
  vector_search_explain = explain_query_execution['stages'][0]['$vectorSearch']
207
+ millis_elapsed = vector_search_explain['explain']['collectStats']['millisElapsed']
208
 
209
+ print(f"Total time for the execution to complete on the database server: {millis_elapsed} milliseconds")
 
 
210
 
211
  return list(results)
212
 
213
def get_embedding(openai_api_key, text):
    """Generate an embedding for the given text using OpenAI's API.

    Args:
        openai_api_key: OpenAI API key (installed globally on the client).
        text: The string to embed.

    Returns:
        The 1536-dimensional embedding vector, or ``None`` when the input
        is invalid or the API call fails.
    """
    # Reject None / empty / non-string input before touching the API.
    if not isinstance(text, str) or not text:
        return None

    openai.api_key = openai_api_key

    try:
        response = openai.embeddings.create(
            input=text,
            model="text-embedding-3-small",
            dimensions=1536,
        )
        return response.data[0].embedding
    except Exception as e:
        # Best-effort: report the failure and signal it with None.
        print(f"Error in get_embedding: {e}")
        return None