File size: 8,711 Bytes
dae5fc7
36320a3
dae5fc7
f750f67
dae5fc7
36320a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6608ce6
36320a3
 
 
 
 
6608ce6
36320a3
 
 
 
 
 
cf8a3af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5927548
cf8a3af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cb0447c
 
 
cf8a3af
 
 
150d83b
95d89c2
 
cf8a3af
92262b1
cf8a3af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3eb2c0f
 
836ca34
4208e5e
cf8a3af
 
 
4208e5e
cf8a3af
 
3eb2c0f
cf8a3af
 
 
 
c0709ce
cf8a3af
3eb2c0f
 
 
 
 
cf8a3af
c0709ce
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
import openai, os, time
import pandas as pd

from document_model import SearchResultItem
from pydantic import ValidationError
from pymongo.collection import Collection
from pymongo.errors import OperationFailure
from pymongo.operations import SearchIndexModel
from pymongo.mongo_client import MongoClient

# Name of the MongoDB database opened by connect_to_database().
DB_NAME = "airbnb_dataset"
# Name of the collection within DB_NAME returned by connect_to_database().
COLLECTION_NAME = "listings_reviews"

def process_records(data_frame):
    """
    Turn a DataFrame into a list of validated listing dictionaries.

    Missing values (anything `pd.isnull` reports as null, e.g. NaN/NaT) are
    replaced with None, both for scalar cells and for the elements of
    list-valued cells. Each cleaned row is then passed through the `Listing`
    model and converted back to a plain dict.

    Args:
    data_frame (pandas.DataFrame): The rows to convert.

    Returns:
    list: One dict per row, or an empty list if any row fails validation
    (the validation error is printed).
    """
    # NOTE(review): `Listing` is not among the names imported at the top of
    # this file; presumably it lives in document_model next to
    # SearchResultItem - confirm and import it.

    def _blank_missing(cell):
        # List cells are cleaned element by element; everything else is
        # treated as a scalar.
        if isinstance(cell, list):
            return [None if pd.isnull(item) else item for item in cell]
        return None if pd.isnull(cell) else cell

    rows = data_frame.to_dict(orient='records')
    for row in rows:
        # Only existing keys are reassigned, so iterating the dict is safe.
        for field in row:
            row[field] = _blank_missing(row[field])

    try:
        # Validate every row; a single failure discards the whole batch.
        return [Listing(**row).dict() for row in rows]
    except ValidationError as err:
        print("Validation error:", err)
        return []
    


def get_embedding(text):
    """
    Return the OpenAI embedding vector for `text`.

    The request uses the "text-embedding-3-small" model with 1536 dimensions.

    Args:
    text (str): The text to embed.

    Returns:
    The embedding of the first item in the API response, or None when `text`
    is empty / not a string or when the API call raises (the error is printed).
    """
    # Guard clause: only non-empty strings are sent to the API.
    if not (text and isinstance(text, str)):
        return None

    try:
        response = openai.embeddings.create(
            input=text,
            model="text-embedding-3-small",
            dimensions=1536,
        )
        return response.data[0].embedding
    except Exception as exc:
        # Best-effort: report the failure and let the caller handle None.
        print(f"Error in get_embedding: {exc}")
        return None

def vector_search_with_filter(user_query, db, collection, additional_stages=None, vector_index="vector_index_text"):
    """
    Perform a pre-filtered vector search in the MongoDB collection based on the user query.

    The `$vectorSearch` stage only considers documents with
    `accommodates >= 2` and `bedrooms <= 7`. After running the search, the
    same pipeline is sent through `explain` to print the server-side
    execution time of the vector search stage.

    Args:
    user_query (str): The user's query string.
    db (MongoClient.database): The database object.
    collection (MongoCollection): The MongoDB collection to search.
    additional_stages (list): Additional aggregation stages appended after the
        vector search stage. Defaults to no extra stages.
    vector_index (str): Name of the vector search index to use.

    Returns:
    list: A list of matching documents, or the string
    "Invalid query or embedding generation failed." when no embedding could
    be produced for the query.
    """

    # Build a fresh list per call instead of relying on a shared mutable default.
    extra_stages = list(additional_stages) if additional_stages else []

    # Generate embedding for the user query
    query_embedding = get_embedding(user_query)

    if query_embedding is None:
        return "Invalid query or embedding generation failed."

    # Define the vector search stage
    vector_search_stage = {
        "$vectorSearch": {
            "index": vector_index,  # specifies the index to use for the search
            "queryVector": query_embedding,  # the vector representing the query
            "path": "text_embeddings",  # field in the documents containing the vectors to search against
            "numCandidates": 150,  # number of candidate matches to consider
            "limit": 20,  # return top 20 matches
            # NOTE(review): presumably both fields must be declared as filter
            # fields in the index definition - confirm against the index.
            "filter": {
                "$and": [
                    {"accommodates": {"$gte": 2}},
                    {"bedrooms": {"$lte": 7}},
                ]
            },
        }
    }

    # Define the aggregate pipeline with the vector search stage and additional stages
    pipeline = [vector_search_stage] + extra_stages

    # Execute the search
    results = collection.aggregate(pipeline)

    # Ask the server to explain the same pipeline. With 'executionStats'
    # verbosity the server executes the plan to collect statistics, so the
    # search effectively runs a second time for diagnostics.
    explain_query_execution = db.command(
        'explain', {
            'aggregate': collection.name,  # collection the aggregation runs on
            'pipeline': pipeline,  # the aggregation pipeline to analyze
            'cursor': {}  # default cursor behavior
        },
        verbosity='executionStats')

    # The timing is purely informational: a differently shaped explain
    # document must not discard search results that were already obtained.
    try:
        vector_search_explain = explain_query_execution['stages'][0]['$vectorSearch']
        millis_elapsed = vector_search_explain['explain']['collectStats']['millisElapsed']
    except (KeyError, IndexError, TypeError):
        print("Execution statistics were not available in the explain output.")
    else:
        print(f"Total time for the execution to complete on the database server: {millis_elapsed} milliseconds")

    return list(results)




def connect_to_database():
    """
    Create a MongoDB client and return the application's database and collection.

    The connection string is read from the MONGODB_ATLAS_CLUSTER_URI
    environment variable.

    Returns:
    tuple: `(db, collection)` for DB_NAME and COLLECTION_NAME.

    Raises:
    KeyError: If MONGODB_ATLAS_CLUSTER_URI is unset or empty. KeyError is kept
    so callers that handled the original missing-variable failure still work.
    """

    # Use .get() so the diagnostic below is actually reachable when the
    # variable is missing, then fail fast rather than connecting with no URI.
    mongo_uri = os.environ.get("MONGODB_ATLAS_CLUSTER_URI")

    if not mongo_uri:
        print("MONGO_URI not set in environment variables")
        raise KeyError("MONGODB_ATLAS_CLUSTER_URI")

    # gateway to interacting with a MongoDB database cluster
    mongo_client = MongoClient(mongo_uri, appname="advanced-rag")
    # NOTE(review): the client is only constructed here, no command is sent;
    # confirm whether a ping is wanted before reporting success.
    print("Connection to MongoDB successful")

    # Pymongo client of database and collection
    db = mongo_client.get_database(DB_NAME)
    collection = db.get_collection(COLLECTION_NAME)

    return db, collection

def vector_search(user_query, db, collection, additional_stages=None, vector_index="vector_index_text"):
    """
    Perform a vector search in the MongoDB collection based on the user query.

    After running the search, the same pipeline is sent through `explain`
    and the explain entry for the `$vectorSearch` stage is printed.

    Args:
    user_query (str): The user's query string.
    db (MongoClient.database): The database object.
    collection (MongoCollection): The MongoDB collection to search.
    additional_stages (list): Additional aggregation stages appended after the
        vector search stage. Defaults to no extra stages.
    vector_index (str): Name of the vector search index to use.

    Returns:
    list: A list of matching documents, or the string
    "Invalid query or embedding generation failed." when no embedding could
    be produced for the query.
    """

    # Build a fresh list per call instead of relying on a shared mutable default.
    extra_stages = list(additional_stages) if additional_stages else []

    # Generate embedding for the user query
    query_embedding = get_embedding(user_query)

    if query_embedding is None:
        return "Invalid query or embedding generation failed."

    # Define the vector search stage
    vector_search_stage = {
        "$vectorSearch": {
            "index": vector_index,  # specifies the index to use for the search
            "queryVector": query_embedding,  # the vector representing the query
            "path": "text_embeddings",  # field in the documents containing the vectors to search against
            "numCandidates": 150,  # number of candidate matches to consider
            "limit": 20,  # return top 20 matches
        }
    }

    # Define the aggregate pipeline with the vector search stage and additional stages
    pipeline = [vector_search_stage] + extra_stages

    # Execute the search
    results = collection.aggregate(pipeline)

    # Ask the server to explain the same pipeline. With 'executionStats'
    # verbosity the server executes the plan to collect statistics, so the
    # search effectively runs a second time for diagnostics.
    explain_query_execution = db.command(
        'explain', {
            'aggregate': collection.name,  # collection the aggregation runs on
            'pipeline': pipeline,  # the aggregation pipeline to analyze
            'cursor': {}  # default cursor behavior
        },
        verbosity='executionStats')

    # The explain output is informational only: if it has an unexpected shape,
    # report that instead of discarding results that were already obtained.
    try:
        vector_search_explain = explain_query_execution['stages'][0]['$vectorSearch']
    except (KeyError, IndexError, TypeError):
        print("No $vectorSearch stage found in the explain output.")
    else:
        print(vector_search_explain)

    return list(results)

def handle_user_prompt(openai_api_key, prompt, db, collection, stages=None, vector_index="vector_index"):
    """
    Answer a user prompt with gpt-4o, using vector search results as context.

    Args:
    openai_api_key (str): API key assigned to `openai.api_key` (module-wide).
    prompt (str): The user's question.
    db (MongoClient.database): The database object passed to vector_search.
    collection (MongoCollection): The MongoDB collection to search.
    stages (list): Additional aggregation stages forwarded to vector_search.
        Defaults to no extra stages.
    vector_index (str): Name of the vector search index to use.

    Returns:
    str: The model's answer when the search produced results.
    tuple: `(message, "No source information available.")` when the search
    failed or returned nothing. NOTE(review): the two return shapes differ;
    kept as-is for existing callers.
    """
    openai.api_key = openai_api_key

    # Always hand vector_search a fresh list rather than a shared mutable default.
    extra_stages = list(stages) if stages else []

    # vector_search returns a list of result documents, or an error string
    # when the query could not be embedded.
    get_knowledge = vector_search(prompt, db, collection, extra_stages, vector_index)

    # An error string is truthy and iterable; without this guard each
    # character would be unpacked into SearchResultItem and raise TypeError.
    if isinstance(get_knowledge, str):
        return get_knowledge, "No source information available."

    # Check if there are any results
    if not get_knowledge:
        return "No results found.", "No source information available."

    # Convert search results into a list of SearchResultItem models
    search_results_models = [
        SearchResultItem(**result)
        for result in get_knowledge
    ]

    # Convert search results into a DataFrame for better rendering in Jupyter
    search_results_df = pd.DataFrame([item.dict() for item in search_results_models])

    # Generate system response using OpenAI's completion
    # NOTE(review): the DataFrame is interpolated via its default string form,
    # which pandas may truncate according to its display options - confirm
    # the model receives the full context.
    content = f"Answer this user question: {prompt} with the following context:\n{search_results_df}"

    completion = openai.chat.completions.create(
        model="gpt-4o",
        messages=[
            {
                "role": "system",
                "content": "You are an AirBnB listing recommendation system."},
            {
                "role": "user",
                "content": content
            }
        ]
    )

    result = completion.choices[0].message.content

    print("###")
    print(f"- User Content:\n{content}\n")
    print("###")
    print(f"- Prompt Completion:\n{result}\n")
    print("###")

    return result