# advanced-rag / custom_utils.py
import openai, os, time
#import pandas as pd
from datasets import load_dataset
#from pydantic import ValidationError
from pymongo.collection import Collection
from pymongo.errors import OperationFailure
from pymongo.mongo_client import MongoClient
from pymongo.operations import SearchIndexModel
DB_NAME = "airbnb_dataset"
COLLECTION_NAME = "listings_reviews"
def connect_to_database():
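    """Connect to MongoDB Atlas using the MONGODB_ATLAS_CLUSTER_URI environment variable and return the (db, collection) handles."""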
MONGODB_ATLAS_CLUSTER_URI = os.environ["MONGODB_ATLAS_CLUSTER_URI"]
mongo_client = MongoClient(MONGODB_ATLAS_CLUSTER_URI, appname="advanced-rag")
db = mongo_client.get_database(DB_NAME)
collection = db.get_collection(COLLECTION_NAME)
return db, collection
def rag_ingestion(collection):
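    """Stream the pre-embedded Airbnb dataset from Hugging Face and replace the collection's contents with it."""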
dataset = load_dataset("bstraehle/airbnb-san-francisco-202403-embed", streaming=True, split="train")
collection.delete_many({})
collection.insert_many(dataset)
return "Manually create a vector search index (in free tier, this feature is not available via SDK)"
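
# The free Atlas tier requires creating the vector search index manually in the Atlas UI.
# On clusters where programmatic creation is supported (and with a recent PyMongo), the index
# that vector_search() expects could be created roughly as sketched below. The "cosine"
# similarity and the field list are assumptions inferred from the $vectorSearch stage in this
# file, not values copied from the deployed index.
def create_vector_search_index(collection, index_name="vector_index"):
    index_model = SearchIndexModel(
        definition={
            "fields": [
                {
                    "type": "vector",
                    "path": "description_embedding",
                    "numDimensions": 1536,  # text-embedding-3-small with dimensions=1536
                    "similarity": "cosine"  # assumption; "euclidean" or "dotProduct" are also valid
                },
                # Filter fields referenced by the $vectorSearch "filter" clause in vector_search()
                {"type": "filter", "path": "accommodates"},
                {"type": "filter", "path": "bedrooms"}
            ]
        },
        name=index_name,
        type="vectorSearch"
    )
    return collection.create_search_index(model=index_model)
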
def rag_retrieval(openai_api_key,
                  prompt,
                  accommodates,
                  bedrooms,
                  db,
                  collection,
                  vector_index="vector_index"):
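    """Embed the prompt and retrieve matching listings via Atlas Vector Search, filtered by accommodates and bedrooms."""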
    ###
    ### Optional stages, two places to narrow or re-rank results:
    ###   Pre-retrieval processing:  index filter inside $vectorSearch (see vector_search)
    ###   Post-retrieval processing: result filter / projection / re-ranking stages below
#match_stage = {
# "$match": {
# "accommodates": { "$eq": 2},
# "bedrooms": { "$eq": 1}
# }
#}
#additional_stages = [match_stage]
###
"""
projection_stage = {
"$project": {
"_id": 0,
"name": 1,
"accommodates": 1,
"address.street": 1,
"address.government_area": 1,
"address.market": 1,
"address.country": 1,
"address.country_code": 1,
"address.location.type": 1,
"address.location.coordinates": 1,
"address.location.is_location_exact": 1,
"summary": 1,
"space": 1,
"neighborhood_overview": 1,
"notes": 1,
"score": {"$meta": "vectorSearchScore"}
}
}
additional_stages = [projection_stage]
"""
###
#review_average_stage = {
# "$addFields": {
# "averageReviewScore": {
# "$divide": [
# {
# "$add": [
# "$review_scores.review_scores_accuracy",
# "$review_scores.review_scores_cleanliness",
# "$review_scores.review_scores_checkin",
# "$review_scores.review_scores_communication",
# "$review_scores.review_scores_location",
# "$review_scores.review_scores_value",
# ]
# },
# 6 # Divide by the number of review score types to get the average
# ]
# },
# # Calculate a score boost factor based on the number of reviews
# "reviewCountBoost": "$number_of_reviews"
# }
#}
#weighting_stage = {
# "$addFields": {
# "combinedScore": {
# # Example formula that combines average review score and review count boost
# "$add": [
# {"$multiply": ["$averageReviewScore", 0.9]}, # Weighted average review score
# {"$multiply": ["$reviewCountBoost", 0.1]} # Weighted review count boost
# ]
# }
# }
#}
# Apply the combinedScore for sorting
#sorting_stage_sort = {
# "$sort": {"combinedScore": -1} # Descending order to boost higher combined scores
#}
#additional_stages = [review_average_stage, weighting_stage, sorting_stage_sort]
    ###
    # No additional post-retrieval stages enabled; filtering happens pre-retrieval
    # inside the $vectorSearch stage (see vector_search below).
    additional_stages = []
    ###
    get_knowledge = vector_search(
        openai_api_key,
        prompt,
        accommodates,
        bedrooms,
        db,
        collection,
        additional_stages,
        vector_index)
if not get_knowledge:
return "No results found.", "No source information available."
print("###")
print(get_knowledge)
print("###")
return get_knowledge
def rag_inference(openai_api_key,
prompt,
search_results):
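    """Generate the final answer with gpt-4o, using the retrieved listings as context."""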
openai.api_key = openai_api_key
content = f"Answer this user question: {prompt} with the following context:\n{search_results}"
completion = openai.chat.completions.create(
model="gpt-4o",
messages=[
{
"role": "system",
"content": "You are an AirBnB listing recommendation system."},
{
"role": "user",
"content": content
}
]
)
return completion.choices[0].message.content
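
# Illustrative end-to-end flow (not executed here). The prompt text and the accommodates /
# bedrooms values are made-up examples; they assume MONGODB_ATLAS_CLUSTER_URI is set and a
# valid OpenAI API key is passed in.
#
#   db, collection = connect_to_database()
#   results = rag_retrieval(openai_api_key, "two-bedroom place near Golden Gate Park", 4, 2, db, collection)
#   answer = rag_inference(openai_api_key, "two-bedroom place near Golden Gate Park", results)
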
def vector_search(openai_api_key,
user_query,
accommodates,
bedrooms,
db,
collection,
additional_stages=[],
vector_index="vector_index"):
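    """Run a $vectorSearch aggregation on description_embedding (pre-filtered by accommodates and bedrooms) and return matching documents without their embeddings."""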
query_embedding = get_text_embedding(openai_api_key, user_query)
if query_embedding is None:
return "Invalid query or embedding generation failed."
#vector_search_stage = {
# "$vectorSearch": {
# "index": vector_index,
# "queryVector": query_embedding,
# "path": "description_embedding",
# "numCandidates": 150,
# "limit": 3,
# }
#}
vector_search_stage = {
"$vectorSearch": {
"index": vector_index,
"queryVector": query_embedding,
"path": "description_embedding",
"numCandidates": 150,
"limit": 10,
"filter": {
"$and": [
{"accommodates": {"$eq": accommodates}},
{"bedrooms": {"$eq": bedrooms}}
]
},
}
}
remove_embedding_stage = {
"$unset": "description_embedding"
}
    # additional_stages is currently disabled; append it here to re-enable post-retrieval processing
    pipeline = [vector_search_stage, remove_embedding_stage]  # + additional_stages
results = collection.aggregate(pipeline)
#explain_query_execution = db.command(
# "explain", {
# "aggregate": collection.name,
# "pipeline": pipeline,
# "cursor": {}
# },
# verbosity='executionStats')
#vector_search_explain = explain_query_execution["stages"][0]["$vectorSearch"]
#millis_elapsed = vector_search_explain["explain"]["collectStats"]["millisElapsed"]
#print(f"Query execution time: {millis_elapsed} milliseconds")
return list(results)
def get_text_embedding(openai_api_key, text):
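    """Return a 1536-dimensional text-embedding-3-small embedding for text, or None if the input is invalid or the API call fails."""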
if not text or not isinstance(text, str):
return None
openai.api_key = openai_api_key
    try:
        embedding = openai.embeddings.create(
            input=text,
            model="text-embedding-3-small",
            dimensions=1536
        ).data[0].embedding
        return embedding
    except Exception as e:
        print(f"Error in get_text_embedding: {e}")
        return None