# NOTE(review): notebook UI residue ("Spaces: / Running / Running") was captured
# in this export; replaced with this comment so the module parses.
import os | |
from typing import List, Optional | |
from pydantic import BaseModel, ValidationError | |
from datetime import datetime | |
import pandas as pd | |
import openai | |
from pymongo.collection import Collection | |
from pymongo.errors import OperationFailure | |
from pymongo.operations import SearchIndexModel | |
from pymongo.mongo_client import MongoClient | |
import time | |
# MongoDB database / collection holding the Airbnb listings-with-reviews dataset.
DB_NAME = "airbnb_dataset"
COLLECTION_NAME = "listings_reviews"
class Host(BaseModel):
    """Pydantic model for the embedded ``host`` sub-document of a listing.

    Fields mirror the Airbnb sample-dataset schema; optional fields may be
    absent/null in the raw data and default to ``None``.
    """
    host_id: str
    host_url: str
    host_name: str
    host_location: str
    host_about: str
    host_response_time: Optional[str] = None
    host_thumbnail_url: str
    host_picture_url: str
    host_response_rate: Optional[int] = None
    host_is_superhost: bool
    host_has_profile_pic: bool
    host_identity_verified: bool
class Location(BaseModel):
    """Geo location of a listing.

    Looks like a GeoJSON-style shape (``type`` + ``coordinates`` as
    [longitude, latitude]) — presumably ``type`` is "Point"; confirm against
    the stored documents.
    """
    type: str
    coordinates: List[float]
    is_location_exact: bool
class Address(BaseModel):
    """Embedded ``address`` sub-document of a listing."""
    street: str
    government_area: str
    market: str
    country: str
    country_code: str
    location: Location
class Review(BaseModel):
    """One guest review attached to a listing.

    NOTE(review): pydantic treats underscore-prefixed names as private
    attributes, not model fields, so ``_id`` is most likely silently dropped
    during validation — confirm, and consider ``Field(alias="_id")`` if the
    review id must be retained.
    """
    _id: str
    date: Optional[datetime] = None
    listing_id: str
    reviewer_id: str
    reviewer_name: Optional[str] = None
    comments: Optional[str] = None
class Listing(BaseModel):
    """Full Airbnb listing document, including embedded host/address/reviews
    and a precomputed text embedding vector.

    NOTE(review): as with ``Review``, the underscore-prefixed ``_id`` is most
    likely ignored by pydantic (treated as private, not a field) — confirm
    whether the listing id survives ``Listing(**record).dict()``.
    """
    _id: int
    listing_url: str
    name: str
    summary: str
    space: str
    description: str
    neighborhood_overview: Optional[str] = None
    notes: Optional[str] = None
    transit: Optional[str] = None
    access: str
    interaction: Optional[str] = None
    house_rules: str
    property_type: str
    room_type: str
    bed_type: str
    minimum_nights: int
    maximum_nights: int
    cancellation_policy: str
    # Scrape/review timestamps may be missing (NaT in the source frame).
    last_scraped: Optional[datetime] = None
    calendar_last_scraped: Optional[datetime] = None
    first_review: Optional[datetime] = None
    last_review: Optional[datetime] = None
    accommodates: int
    # Counts that may be absent default to 0 rather than None.
    bedrooms: Optional[float] = 0
    beds: Optional[float] = 0
    number_of_reviews: int
    bathrooms: Optional[float] = 0
    amenities: List[str]
    price: int
    security_deposit: Optional[float] = None
    cleaning_fee: Optional[float] = None
    extra_people: int
    guests_included: int
    images: dict
    host: Host
    address: Address
    availability: dict
    review_scores: dict
    reviews: List[Review]
    # Vector searched by the $vectorSearch stages below (path "text_embeddings").
    text_embeddings: List[float]
def _nan_to_none(value):
    """Map pandas missing markers (NaN/NaT) to None; lists are cleaned element-wise."""
    if isinstance(value, list):
        return [None if pd.isnull(v) else v for v in value]
    # pd.isnull treats dicts and other non-scalar objects as not-null, so
    # nested sub-documents pass through unchanged.
    return None if pd.isnull(value) else value


def process_records(data_frame):
    """Convert a DataFrame of raw listing rows into validated ``Listing`` dicts.

    Args:
        data_frame: pandas DataFrame, one row per listing document.

    Returns:
        list[dict]: validated listings, or ``[]`` if any record fails
        validation (all-or-nothing, matching the original behavior — a single
        bad record discards the whole batch).
    """
    records = data_frame.to_dict(orient='records')
    # Replace NaN/NaT with None so pydantic's Optional fields validate cleanly.
    for record in records:
        for key, value in record.items():
            record[key] = _nan_to_none(value)
    try:
        return [Listing(**record).dict() for record in records]
    except ValidationError as e:
        print("Validation error:", e)
        return []
def get_embedding(text):
    """Return a 1536-dim embedding for ``text`` from OpenAI's
    ``text-embedding-3-small`` model, or ``None`` for invalid input or on any
    API failure (best-effort; the error is printed, not raised).
    """
    # Guard clause: only non-empty strings are embeddable.
    if not isinstance(text, str) or not text:
        return None
    try:
        response = openai.embeddings.create(
            input=text,
            model="text-embedding-3-small",
            dimensions=1536,
        )
        return response.data[0].embedding
    except Exception as exc:
        print(f"Error in get_embedding: {exc}")
        return None
def vector_search_with_filter(user_query, db, collection, additional_stages=None,
                              vector_index="vector_index_text",
                              min_accommodates=2, max_bedrooms=7):
    """
    Perform a filtered vector search in the MongoDB collection based on the user query.

    Args:
        user_query (str): The user's query string.
        db (MongoClient.database): The database object (used for the explain command).
        collection (MongoCollection): The MongoDB collection to search.
        additional_stages (list | None): Additional aggregation stages to append
            to the pipeline. Defaults to none. (Fixed: was a mutable default
            argument ``[]`` shared across calls.)
        vector_index (str): Name of the Atlas vector index to search.
        min_accommodates (int): Lower bound for the ``accommodates`` pre-filter
            (previously hard-coded to 2).
        max_bedrooms (int): Upper bound for the ``bedrooms`` pre-filter
            (previously hard-coded to 7).

    Returns:
        list: Matching documents, or an error string if embedding generation failed.
    """
    if additional_stages is None:
        additional_stages = []
    # Generate embedding for the user query.
    query_embedding = get_embedding(user_query)
    if query_embedding is None:
        return "Invalid query or embedding generation failed."
    # $vectorSearch stage: kNN over the stored "text_embeddings" field, with a
    # pre-filter on capacity/bedrooms applied before scoring.
    vector_search_stage = {
        "$vectorSearch": {
            "index": vector_index,          # index to use for the search
            "queryVector": query_embedding, # the vector representing the query
            "path": "text_embeddings",      # document field holding the vectors
            "numCandidates": 150,           # candidate matches to consider
            "limit": 20,                    # return top 20 matches
            "filter": {
                "$and": [
                    {"accommodates": {"$gte": min_accommodates}},
                    {"bedrooms": {"$lte": max_bedrooms}},
                ]
            },
        }
    }
    pipeline = [vector_search_stage] + additional_stages
    # Execute the search.
    results = collection.aggregate(pipeline)
    # Ask the server how it executed the pipeline (explain does not re-run it).
    explain_query_execution = db.command(
        'explain',
        {
            'aggregate': collection.name,  # collection the aggregation runs on
            'pipeline': pipeline,          # the pipeline to analyze
            'cursor': {},                  # default cursor behavior
        },
        verbosity='executionStats',        # per-stage execution statistics
    )
    vector_search_explain = explain_query_execution['stages'][0]['$vectorSearch']
    millis_elapsed = vector_search_explain['explain']['collectStats']['millisElapsed']
    print(f"Total time for the execution to complete on the database server: {millis_elapsed} milliseconds")
    return list(results)
def connect_to_database():
    """Establish a connection to MongoDB Atlas.

    Reads the cluster URI from the ``MONGODB_ATLAS_CLUSTER_URI`` environment
    variable.

    Returns:
        tuple: (db, collection) for ``DB_NAME`` / ``COLLECTION_NAME``.

    Raises:
        KeyError: if the URI environment variable is not set (same exception
            type the original raised — its ``if not MONGO_URI`` check was dead
            code because ``os.environ[...]`` had already raised).
    """
    mongo_uri = os.environ.get("MONGODB_ATLAS_CLUSTER_URI")
    if not mongo_uri:
        print("MONGO_URI not set in environment variables")
        raise KeyError("MONGODB_ATLAS_CLUSTER_URI not set in environment variables")
    # Gateway to interacting with a MongoDB database cluster.
    mongo_client = MongoClient(mongo_uri, appname="advanced-rag")
    print("Connection to MongoDB successful")
    # Pymongo handles for the target database and collection.
    db = mongo_client.get_database(DB_NAME)
    collection = db.get_collection(COLLECTION_NAME)
    return db, collection
def vector_search(user_query, db, collection, additional_stages=None,
                  vector_index="vector_index_text"):
    """
    Perform a vector search in the MongoDB collection based on the user query.

    Args:
        user_query (str): The user's query string.
        db (MongoClient.database): The database object (used for the explain command).
        collection (MongoCollection): The MongoDB collection to search.
        additional_stages (list | None): Additional aggregation stages to append
            to the pipeline. Defaults to none. (Fixed: was a mutable default
            argument ``[]`` shared across calls.)
        vector_index (str): Name of the Atlas vector index to search.

    Returns:
        list: Matching documents, or an error string if embedding generation failed.
    """
    if additional_stages is None:
        additional_stages = []
    # Generate embedding for the user query.
    query_embedding = get_embedding(user_query)
    if query_embedding is None:
        return "Invalid query or embedding generation failed."
    # $vectorSearch stage: unfiltered kNN over the stored "text_embeddings" field.
    vector_search_stage = {
        "$vectorSearch": {
            "index": vector_index,          # index to use for the search
            "queryVector": query_embedding, # the vector representing the query
            "path": "text_embeddings",      # document field holding the vectors
            "numCandidates": 150,           # candidate matches to consider
            "limit": 20,                    # return top 20 matches
        }
    }
    pipeline = [vector_search_stage] + additional_stages
    # Execute the search.
    results = collection.aggregate(pipeline)
    # Ask the server how it executed the pipeline (explain does not re-run it).
    explain_query_execution = db.command(
        'explain',
        {
            'aggregate': collection.name,  # collection the aggregation runs on
            'pipeline': pipeline,          # the pipeline to analyze
            'cursor': {},                  # default cursor behavior
        },
        verbosity='executionStats',        # per-stage execution statistics
    )
    vector_search_explain = explain_query_execution['stages'][0]['$vectorSearch']
    # NOTE(review): unlike vector_search_with_filter, this prints the whole
    # explain block rather than millisElapsed — presumably because the explain
    # structure differed; confirm before unifying the two.
    print(vector_search_explain)
    return list(results)
class SearchResultItem(BaseModel):
    """Subset of listing fields surfaced to the LLM as retrieval context."""
    name: str
    accommodates: Optional[int] = None
    bedrooms: Optional[int] = None
    address: Address
    # Fixed: was `space: str = None` — an inconsistent annotation (None default
    # on a non-Optional str field).
    space: Optional[str] = None
def handle_user_prompt(openai_api_key, prompt, db, collection, stages=None,
                       vector_index="vector_index"):
    """Answer a user prompt with GPT-4o, grounded in vector-search results.

    Args:
        openai_api_key (str): API key, set globally on the ``openai`` module.
        prompt (str): The user's question.
        db: The pymongo database object (forwarded to ``vector_search``).
        collection: The pymongo collection to search.
        stages (list | None): Extra aggregation stages for the search pipeline.
            Defaults to none. (Fixed: was a mutable default argument ``[]``.)
        vector_index (str): Name of the Atlas vector index to use.

    Returns:
        str: The model's completion text.
        NOTE(review): on empty search results this returns a 2-tuple of
        strings, not a str — preserved as-is for callers that unpack it, but
        the inconsistent return type is worth confirming upstream.
    """
    if stages is None:
        stages = []
    openai.api_key = openai_api_key
    # Retrieve the most similar listings for the prompt.
    get_knowledge = vector_search(prompt, db, collection, stages, vector_index)
    if not get_knowledge:
        return "No results found.", "No source information available."
    # Validate/shape each hit into the fields we expose to the model.
    search_results_models = [
        SearchResultItem(**result)
        for result in get_knowledge
    ]
    # DataFrame renders compactly in the prompt (and nicely in Jupyter).
    search_results_df = pd.DataFrame([item.dict() for item in search_results_models])
    # Generate the system response using OpenAI's chat completion.
    content = f"Answer this user question: {prompt} with the following context:\n{search_results_df}"
    completion = openai.chat.completions.create(
        model="gpt-4o",
        messages=[
            {
                "role": "system",
                "content": "You are an AirBnB listing recommendation system."},
            {
                "role": "user",
                "content": content
            }
        ]
    )
    result = completion.choices[0].message.content
    print("###")
    print(f"- User Content:\n{content}\n")
    print("###")
    print(f"- Prompt Completion:\n{result}\n")
    print("###")
    return result