bstraehle committed on
Commit
36320a3
·
verified ·
1 Parent(s): a635d90

Create util.py

Browse files
Files changed (1) hide show
  1. util.py +268 -0
util.py ADDED
@@ -0,0 +1,268 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""Utilities for loading the Airbnb dataset into MongoDB and running vector search with OpenAI embeddings."""
import os
from typing import List, Optional
from pydantic import BaseModel, ValidationError
from datetime import datetime
import pandas as pd
import openai
from pymongo.collection import Collection
from pymongo.errors import OperationFailure
from pymongo.operations import SearchIndexModel
from pymongo.mongo_client import MongoClient
import time

from dotenv import load_dotenv, find_dotenv
_ = load_dotenv(find_dotenv()) # read local .env file
# NOTE(review): raises KeyError at import time if OPENAI_API_KEY is absent -- confirm this is intended fail-fast behavior.
openai.api_key = os.environ['OPENAI_API_KEY']

# Target database/collection names used by connect_to_database().
DB_NAME = "airbnb_dataset"
COLLECTION_NAME = "listings_reviews"
19
+
20
class Host(BaseModel):
    """Pydantic model for the host sub-document embedded in a listing."""
    host_id: str
    host_url: str
    host_name: str
    host_location: str
    host_about: str
    host_response_time: Optional[str] = None
    host_thumbnail_url: str
    host_picture_url: str
    host_response_rate: Optional[int] = None  # presumably a 0-100 percentage -- TODO confirm against the dataset
    host_is_superhost: bool
    host_has_profile_pic: bool
    host_identity_verified: bool
33
+
34
class Location(BaseModel):
    """Geo point for a listing's address."""
    type: str  # presumably "Point" (GeoJSON-style) -- TODO confirm
    coordinates: List[float]  # presumably [longitude, latitude] order -- TODO confirm
    is_location_exact: bool
38
+
39
class Address(BaseModel):
    """Postal/geographic address sub-document embedded in a listing."""
    street: str
    government_area: str
    market: str
    country: str
    country_code: str
    location: Location  # nested geo point
46
+
47
class Review(BaseModel):
    """A single guest review attached to a listing."""
    # NOTE(review): pydantic treats leading-underscore names as private
    # attributes, so `_id` is likely NOT a validated/serialized field here --
    # verify whether review ids actually survive the round trip.
    _id: str
    date: Optional[datetime] = None
    listing_id: str
    reviewer_id: str
    reviewer_name: Optional[str] = None
    comments: Optional[str] = None
54
+
55
class Listing(BaseModel):
    """Pydantic model for one Airbnb listing document, including its text embedding."""
    # NOTE(review): as with Review, pydantic treats `_id` as a private
    # attribute, so the Mongo document id is likely dropped during
    # validation -- verify.
    _id: int
    listing_url: str
    name: str
    summary: str
    space: str
    description: str
    neighborhood_overview: Optional[str] = None
    notes: Optional[str] = None
    transit: Optional[str] = None
    access: str
    interaction: Optional[str] = None
    house_rules: str
    property_type: str
    room_type: str
    bed_type: str
    minimum_nights: int
    maximum_nights: int
    cancellation_policy: str
    # Scrape/review timestamps; None when the source cell was NaT.
    last_scraped: Optional[datetime] = None
    calendar_last_scraped: Optional[datetime] = None
    first_review: Optional[datetime] = None
    last_review: Optional[datetime] = None
    accommodates: int
    # Room counts default to 0 (not None) when missing.
    bedrooms: Optional[float] = 0
    beds: Optional[float] = 0
    number_of_reviews: int
    bathrooms: Optional[float] = 0
    amenities: List[str]
    price: int  # presumably a nightly price in whole currency units -- TODO confirm
    security_deposit: Optional[float] = None
    cleaning_fee: Optional[float] = None
    extra_people: int
    guests_included: int
    # Loosely-typed sub-documents kept as plain dicts.
    images: dict
    host: Host
    address: Address
    availability: dict
    review_scores: dict
    reviews: List[Review]
    # Embedding vector for this listing's text; must match the index
    # dimensions configured in setup_vector_search_index (1536).
    text_embeddings: List[float]
96
+
97
def process_records(data_frame):
    """Convert a pandas DataFrame into validated listing dicts.

    NaN/NaT cells -- including values nested inside list cells -- are
    replaced with None so they validate as Optional fields. Returns the
    records as dicts produced by ``Listing(...).dict()``; if any record
    fails validation the error is printed and an empty list is returned.
    """
    records = data_frame.to_dict(orient='records')

    # Scrub missing values in place before validation.
    for rec in records:
        for field, val in rec.items():
            if isinstance(val, list):
                # pd.isnull() is elementwise on lists, so scrub each item.
                rec[field] = [None if pd.isnull(item) else item for item in val]
            elif pd.isnull(val):
                rec[field] = None

    try:
        # Validate every record through the Listing model.
        return [Listing(**rec).dict() for rec in records]
    except ValidationError as err:
        print("Validation error:", err)
        return []
117
+
118
+
119
+
120
def get_embedding(text):
    """Generate an embedding for the given text using OpenAI's API."""
    # Guard clause: reject empty/None or non-string input.
    if not text or not isinstance(text, str):
        return None

    try:
        # Request a 1536-dimensional embedding for the input text.
        response = openai.embeddings.create(
            input=text,
            model="text-embedding-3-small",
            dimensions=1536,
        )
        return response.data[0].embedding
    except Exception as e:
        # Best-effort: report the failure and signal it with None.
        print(f"Error in get_embedding: {e}")
        return None
136
+
137
+
138
def setup_vector_search_index(collection: Collection,
                              text_embedding_field_name: str = "text_embeddings",
                              vector_search_index_name: str = "vector_index_text"):
    """
    Sets up a vector search index for a MongoDB collection based on text embeddings.

    Parameters:
    - collection (Collection): The MongoDB collection to which the index is applied.
    - text_embedding_field_name (str): The field in the documents that contains the text embeddings.
    - vector_search_index_name (str): The name for the vector search index.

    Returns:
    - None
    """
    # Define the model for the vector search index
    vector_search_index_model = SearchIndexModel(
        definition={
            "mappings": { # describes how fields in the database documents are indexed and stored
                "dynamic": True, # automatically index new fields that appear in the document
                "fields": { # properties of the fields that will be indexed.
                    text_embedding_field_name: {
                        "dimensions": 1536, # size of the vector; must match the embedding model's output (see get_embedding)
                        "similarity": "cosine", # algorithm used to compute the similarity between vectors
                        "type": "knnVector",
                    }
                },
            }
        },
        name=vector_search_index_name, # identifier for the vector search index
    )

    # Check if the index already exists
    # NOTE(review): list_indexes() returns regular collection indexes;
    # Atlas Search indexes may not appear in it, so this existence check
    # could always be False and re-attempt creation -- confirm against
    # pymongo's list_search_indexes() on the target deployment.
    index_exists = False
    for index in collection.list_indexes():
        if index['name'] == vector_search_index_name:
            index_exists = True
            break

    # Create the index if it doesn't exist
    if not index_exists:
        try:
            result = collection.create_search_index(vector_search_index_model)
            print("Creating index...")
            time.sleep(20) # Sleep for 20 seconds, adding sleep to ensure vector index has completed initial sync before utilization
            print(f"Index created successfully: {result}")
            print("Wait a few minutes before conducting search with index to ensure index initialization.")
        except OperationFailure as e:
            print(f"Error creating vector search index: {str(e)}")
    else:
        print(f"Index '{vector_search_index_name}' already exists.")
188
+
189
+
190
def vector_search_with_filter(user_query, db, collection, additional_stages=None, vector_index="vector_index_text"):
    """
    Perform a vector search in the MongoDB collection based on the user query.

    Args:
        user_query (str): The user's query string.
        db (MongoClient.database): The database object (used for the `explain` command).
        collection (MongoCollection): The MongoDB collection to search.
        additional_stages (list | None): Additional aggregation stages appended
            after the vector search stage. None (the default) means no extra stages.
        vector_index (str): Name of the Atlas vector search index to query.

    Returns:
        list: A list of matching documents, or an error string if the query
        embedding could not be generated.
    """
    # Fix: the original used a mutable default argument (additional_stages=[]),
    # which is shared across calls; use a None sentinel instead.
    if additional_stages is None:
        additional_stages = []

    # Generate embedding for the user query
    query_embedding = get_embedding(user_query)

    if query_embedding is None:
        return "Invalid query or embedding generation failed."

    # Define the vector search stage
    vector_search_stage = {
        "$vectorSearch": {
            "index": vector_index, # specifies the index to use for the search
            "queryVector": query_embedding, # the vector representing the query
            "path": "text_embeddings", # field in the documents containing the vectors to search against
            "numCandidates": 150, # number of candidate matches to consider
            "limit": 20, # return top 20 matches
            "filter": {
                # Pre-filter: listings that sleep at least 2 and have at most 7 bedrooms.
                "$and": [
                    {"accommodates": {"$gte": 2}},
                    {"bedrooms": {"$lte": 7}}
                ]
            },
        }
    }

    # Define the aggregate pipeline with the vector search stage and additional stages
    pipeline = [vector_search_stage] + additional_stages

    # Execute the search
    results = collection.aggregate(pipeline)

    explain_query_execution = db.command( # sends a database command directly to the MongoDB server
        'explain', { # return information about how MongoDB executes a query or command without actually running it
            'aggregate': collection.name, # specifies the name of the collection on which the aggregation is performed
            'pipeline': pipeline, # the aggregation pipeline to analyze
            'cursor': {} # indicates that default cursor behavior should be used
        },
        verbosity='executionStats') # detailed statistics about the execution of each stage of the aggregation pipeline

    # NOTE(review): this path into the explain output is server/Atlas-version
    # dependent and raises KeyError if the shape differs -- confirm against
    # the target MongoDB version before relying on it.
    vector_search_explain = explain_query_execution['stages'][0]['$vectorSearch']
    millis_elapsed = vector_search_explain['explain']['collectStats']['millisElapsed']

    print(f"Total time for the execution to complete on the database server: {millis_elapsed} milliseconds")

    return list(results)
248
+
249
+
250
+
251
+
252
def connect_to_database():
    """Establish connection to the MongoDB.

    Reads the connection string from the MONGO_URI environment variable and
    returns the (database, collection) pair named by the module-level
    DB_NAME / COLLECTION_NAME constants.

    Raises:
        EnvironmentError: if MONGO_URI is not set. (The original code only
        printed a warning and then passed None to MongoClient, which silently
        falls back to a default/local connection target.)
    """
    MONGO_URI = os.environ.get("MONGO_URI")

    if not MONGO_URI:
        raise EnvironmentError("MONGO_URI not set in environment variables")

    # gateway to interacting with a MongoDB database cluster
    # NOTE(review): MongoClient connects lazily, so this print does not prove
    # the server is reachable -- consider a 'ping' command to verify.
    mongo_client = MongoClient(MONGO_URI, appname="devrel.deeplearningai.python")
    print("Connection to MongoDB successful")

    # Pymongo client of database and collection
    db = mongo_client.get_database(DB_NAME)
    collection = db.get_collection(COLLECTION_NAME)

    return db, collection