Spaces:
Running
Running
File size: 6,545 Bytes
dae5fc7 36320a3 dae5fc7 c4136a8 2b1e3b6 dae5fc7 36320a3 738445b cd4a9a3 3271e72 0a57b44 3271e72 738445b 0a57b44 738445b 6931de0 738445b 6931de0 738445b 6931de0 738445b 6931de0 738445b 6931de0 738445b 36320a3 2b1e3b6 36320a3 2b1e3b6 36320a3 2b1e3b6 36320a3 2b1e3b6 36320a3 2b1e3b6 36320a3 dd927c9 36320a3 738445b 36320a3 738445b 075f373 19893e0 738445b cf8a3af 7ad4a46 cf8a3af 19893e0 cf8a3af 738445b cf8a3af 738445b 1183b4a 738445b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 |
import openai, os, time
import pandas as pd
from datasets import load_dataset
from document_model import Listing, SearchResultItem
from pydantic import ValidationError
from pymongo.collection import Collection
from pymongo.errors import OperationFailure
from pymongo.operations import SearchIndexModel
from pymongo.mongo_client import MongoClient
# MongoDB namespace for the ingested Airbnb listings dataset:
# database name and the collection holding listing + review documents.
DB_NAME = "airbnb_dataset"
COLLECTION_NAME = "listings_reviews"
def connect_to_database():
    """Open a MongoDB Atlas connection and return database/collection handles.

    Reads the cluster URI from the ``MONGODB_ATLAS_CLUSTER_URI`` environment
    variable (raises KeyError if it is unset).

    Returns:
        tuple: ``(database, collection)`` for the Airbnb dataset namespace.
    """
    atlas_uri = os.environ["MONGODB_ATLAS_CLUSTER_URI"]
    client = MongoClient(atlas_uri, appname="advanced-rag")
    database = client.get_database(DB_NAME)
    return database, database.get_collection(COLLECTION_NAME)
def rag_ingestion(collection):
    """Stream the Airbnb embeddings dataset and replace the collection contents.

    Loads ``MongoDB/airbnb_embeddings`` from Hugging Face, cleans/validates
    the records via ``process_records``, empties the target collection, and
    bulk-inserts the resulting documents.

    Args:
        collection: The MongoDB collection to (re)populate.

    Returns:
        str: A reminder that the vector search index must be created manually.
    """
    raw_stream = load_dataset("MongoDB/airbnb_embeddings", streaming=True, split="train")
    frame = pd.DataFrame(raw_stream)
    documents = process_records(frame)
    collection.delete_many({})  # wipe any previous ingestion first
    collection.insert_many(documents)
    return "Manually create a vector search index (in free tier, this feature is not available via SDK)"
def rag_retrieval(openai_api_key, prompt, db, collection, stages=None, vector_index="vector_index"):
    """Retrieve listings relevant to *prompt* via Atlas vector search.

    Args:
        openai_api_key (str): API key used to embed the prompt.
        prompt (str): The user's natural-language query.
        db: The MongoDB database handle (used for explain output).
        collection: The MongoDB collection to search.
        stages (list | None): Extra aggregation stages appended after the
            vector search stage. Defaults to no extra stages.
        vector_index (str): Name of the Atlas vector search index.

    Returns:
        pandas.DataFrame: One row per search hit, or — when nothing is
        found — a 2-tuple of fallback message strings.
        NOTE(review): the two return shapes differ; callers must handle both.
    """
    # Fix: the original used a mutable default (`stages=[]`), which is shared
    # across calls and can silently accumulate state. Use a None sentinel.
    stages = [] if stages is None else stages
    # vector_search returns a list of dictionaries matching SearchResultItem.
    get_knowledge = vector_search(openai_api_key, prompt, db, collection, stages, vector_index)
    # No hits: return the fallback message pair.
    if not get_knowledge:
        return "No results found.", "No source information available."
    # Validate each raw result through the pydantic model.
    search_results_models = [
        SearchResultItem(**result)
        for result in get_knowledge
    ]
    # Convert search results into a DataFrame for better rendering in Jupyter.
    search_results_df = pd.DataFrame([item.dict() for item in search_results_models])
    print("###")
    print(search_results_df)
    print("###")
    return search_results_df
def rag_inference(openai_api_key, prompt, search_results):
    """Answer the user's question with GPT-4o, grounded in the search results.

    Args:
        openai_api_key (str): OpenAI API key.
        prompt (str): The user's question.
        search_results: Retrieved context (e.g. the DataFrame from
            ``rag_retrieval``), interpolated into the user message.

    Returns:
        str: The model's completion text.
    """
    openai.api_key = openai_api_key
    # Fold the question and retrieved context into a single user message.
    content = f"Answer this user question: {prompt} with the following context:\n{search_results}"
    response = openai.chat.completions.create(
        model="gpt-4o",
        messages=[
            {"role": "system", "content": "You are an AirBnB listing recommendation system."},
            {"role": "user", "content": content},
        ],
    )
    answer = response.choices[0].message.content
    print("###")
    print(answer)
    print("###")
    return answer
def process_records(data_frame):
    """Turn a DataFrame of listings into validated, Mongo-ready dicts.

    Scrubs pandas missing values (NaN/NaT) to ``None`` — both scalars and
    elements inside list-valued fields — then validates every record against
    the ``Listing`` model.

    Args:
        data_frame (pandas.DataFrame): Raw listing records.

    Returns:
        list[dict]: Validated documents, or ``[]`` if validation fails
        (the error is printed, not raised — deliberate best-effort).
    """
    records = data_frame.to_dict(orient="records")
    for record in records:
        for field, value in record.items():
            if isinstance(value, list):
                # Replace missing entries inside list fields element-wise.
                record[field] = [None if pd.isnull(item) else item for item in value]
            elif pd.isnull(value):
                # Scalar NaN/NaT becomes None so Mongo stores a null.
                record[field] = None
    try:
        # Round-trip through the pydantic model to validate and normalize.
        return [Listing(**record).dict() for record in records]
    except ValidationError as e:
        print("Validation error:", e)
        return []
def vector_search(openai_api_key, user_query, db, collection, additional_stages=None, vector_index="vector_index"):
    """
    Perform a vector search in the MongoDB collection based on the user query.

    Args:
        openai_api_key (str): API key used to embed the query text.
        user_query (str): The user's query string.
        db (MongoClient.database): The database object.
        collection (MongoCollection): The MongoDB collection to search.
        additional_stages (list | None): Additional aggregation stages to
            append after the $vectorSearch stage. Defaults to no extra stages.
        vector_index (str): Name of the Atlas vector search index to use.

    Returns:
        list: A list of matching documents; empty if embedding generation failed.
    """
    # Fix: the original used a mutable default (`additional_stages=[]`),
    # shared across calls. Normalize a None sentinel to a fresh list.
    additional_stages = [] if additional_stages is None else additional_stages
    # Generate embedding for the user query
    query_embedding = get_embedding(openai_api_key, user_query)
    if query_embedding is None:
        # Fix: the original returned an error *string* here, while every
        # caller treats the result as a list of dicts (iterating a string
        # would crash). Keep the return type consistent: report and return [].
        print("Invalid query or embedding generation failed.")
        return []
    # Define the vector search stage
    vector_search_stage = {
        "$vectorSearch": {
            "index": vector_index,           # index to use for the search
            "queryVector": query_embedding,  # the vector representing the query
            "path": "text_embeddings",       # document field holding the stored vectors
            "numCandidates": 150,            # candidate pool considered before ranking
            "limit": 20,                     # return top 20 matches
            # Hard-coded pre-filter: only listings for 2 guests with 1 bedroom.
            "filter": {
                "$and": [
                    {"accommodates": {"$eq": 2}},
                    {"bedrooms": {"$eq": 1}}
                ]
            },
        }
    }
    # Aggregate pipeline: vector search first, then any caller-supplied stages.
    pipeline = [vector_search_stage] + additional_stages
    # Execute the search
    results = collection.aggregate(pipeline)
    # Ask the server how it executes the pipeline; executionStats verbosity
    # yields per-stage statistics without re-running the aggregation itself.
    explain_query_execution = db.command(
        'explain',
        {
            'aggregate': collection.name,  # collection the aggregation runs on
            'pipeline': pipeline,          # the aggregation pipeline to analyze
            'cursor': {}                   # default cursor behavior
        },
        verbosity='executionStats')
    vector_search_explain = explain_query_execution['stages'][0]['$vectorSearch']
    print(vector_search_explain)
    return list(results)
def get_embedding(openai_api_key, text):
    """Generate an embedding for the given text using OpenAI's API.

    Args:
        openai_api_key (str): OpenAI API key.
        text (str): The text to embed.

    Returns:
        list[float] | None: The 1536-dim embedding, or None for invalid
        input or on any API error (error is printed — deliberate best-effort).
    """
    # Reject non-string or empty input before touching the API.
    if not isinstance(text, str) or not text:
        return None
    openai.api_key = openai_api_key
    try:
        response = openai.embeddings.create(
            input=text,
            model="text-embedding-3-small",
            dimensions=1536,
        )
        return response.data[0].embedding
    except Exception as e:
        print(f"Error in get_embedding: {e}")
        return None