bstraehle committed
Commit 380302b · verified · 1 Parent(s): 7e7435e

Update custom_utils.py

Files changed (1)
  1. custom_utils.py +49 -125
custom_utils.py CHANGED
@@ -28,22 +28,23 @@ def rag_retrieval_naive(openai_api_key,
                         db,
                         collection,
                         vector_index="vector_index"):
-
-    get_knowledge = vector_search_naive(
+    # Naive RAG: Semantic search
+    retrieval_result = vector_search_naive(
         openai_api_key,
         prompt,
         db,
         collection,
-        vector_index)
+        vector_index
+    )
 
-    if not get_knowledge:
-        return "No results found.", "No source information available."
+    if not retrieval_result:
+        return "No results found."
 
-    print("###")
-    print(get_knowledge)
-    print("###")
+    #print("###")
+    #print(retrieval_result)
+    #print("###")
 
-    return get_knowledge
+    return retrieval_result
 
 def rag_retrieval_advanced(openai_api_key,
                            prompt,
@@ -52,43 +53,23 @@ def rag_retrieval_advanced(openai_api_key,
                            db,
                            collection,
                            vector_index="vector_index"):
-    ###
-    ### Pre-retrieval processing: index filter
-    ### Post-retrieval processing: result filter
+    # Advanced RAG: Semantic search plus...
+
+    # 1a) Pre-retrieval processing: index filter (accomodates, bedrooms) plus...
+
+    # 1b) Post-retrieval processing: result filter (accomodates, bedrooms) plus...
+
     #match_stage = {
     #    "$match": {
     #        "accommodates": { "$eq": 2},
     #        "bedrooms": { "$eq": 1}
     #    }
     #}
-
-    #additional_stages = [match_stage]
-    ###
-    """
-    projection_stage = {
-        "$project": {
-            "_id": 0,
-            "name": 1,
-            "accommodates": 1,
-            "address.street": 1,
-            "address.government_area": 1,
-            "address.market": 1,
-            "address.country": 1,
-            "address.country_code": 1,
-            "address.location.type": 1,
-            "address.location.coordinates": 1,
-            "address.location.is_location_exact": 1,
-            "summary": 1,
-            "space": 1,
-            "neighborhood_overview": 1,
-            "notes": 1,
-            "score": {"$meta": "vectorSearchScore"}
-        }
-    }
 
-    additional_stages = [projection_stage]
-    """
-    ###
+    #additional_stages = [match_stage]
+
+    # 2) Average review score and review count boost, sorted in descending order
+
     review_average_stage = {
         "$addFields": {
             "averageReviewScore": {
@@ -104,10 +85,9 @@ def rag_retrieval_advanced(openai_api_key,
                             "$review_scores_value",
                         ]
                     },
-                    6 # Divide by the number of review score types to get the average
+                    7
                 ]
             },
-            # Calculate a score boost factor based on the number of reviews
             "reviewCountBoost": "$number_of_reviews"
         }
     }
@@ -115,27 +95,21 @@ def rag_retrieval_advanced(openai_api_key,
     weighting_stage = {
         "$addFields": {
             "combinedScore": {
-                # Example formula that combines average review score and review count boost
                 "$add": [
-                    {"$multiply": ["$averageReviewScore", 0.9]}, # Weighted average review score
-                    {"$multiply": ["$reviewCountBoost", 0.1]} # Weighted review count boost
+                    {"$multiply": ["$averageReviewScore", 0.9]},
+                    {"$multiply": ["$reviewCountBoost", 0.1]},
                 ]
             }
         }
     }
 
-    # Apply the combinedScore for sorting
     sorting_stage_sort = {
-        "$sort": {"combinedScore": -1} # Descending order to boost higher combined scores
+        "$sort": {"combinedScore": -1}
     }
 
     additional_stages = [review_average_stage, weighting_stage, sorting_stage_sort]
-    ###
-    #additional_stages = []
-    ###
-    ###
 
-    get_knowledge = vector_search_advanced(
+    retrieval_result = vector_search_advanced(
         openai_api_key,
         prompt,
         accomodates,
@@ -143,45 +117,29 @@ def rag_retrieval_advanced(openai_api_key,
         db,
         collection,
         additional_stages,
-        vector_index)
+        vector_index
+    )
 
-    if not get_knowledge:
-        return "No results found.", "No source information available."
+    if not retrieval_result:
+        return "No results found."
 
-    print("###")
-    print(get_knowledge)
-    print("###")
+    #print("###")
+    #print(retrieval_result)
+    #print("###")
 
-    return get_knowledge
+    return retrieval_result
 
-def rag_inference(openai_api_key,
-                  prompt,
-                  search_results):
-    openai.api_key = openai_api_key
-
-    content = f"Answer this user question: {prompt} with the following context:\n{search_results}"
-
-    completion = openai.chat.completions.create(
-        model="gpt-4o",
-        messages=[
-            {
-                "role": "system",
-                "content": "You are an AirBnB listing recommendation system."},
-            {
-                "role": "user",
-                "content": content
-            }
-        ]
-    )
+def inference(openai_api_key, prompt):
+    content = f"Answer this user question: {prompt}"
+    return invoke_llm(openai_api_key, content)
 
-    return completion.choices[0].message.content
+def rag_inference(openai_api_key, prompt, retrieval_result):
+    content = f"Answer this user question: {prompt} with the following context:\n{retrieval_result}"
+    return invoke_llm(openai_api_key, content)
 
-def inference(openai_api_key,
-              prompt):
+def invoke_llm(openai_api_key, content):
     openai.api_key = openai_api_key
-
-    content = f"Answer this user question: {prompt}"
-
+
     completion = openai.chat.completions.create(
         model="gpt-4o",
         messages=[
@@ -196,7 +154,7 @@ def inference(openai_api_key,
     )
 
     return completion.choices[0].message.content
-
+
 def vector_search_naive(openai_api_key,
                         user_query,
                         db,
@@ -223,21 +181,7 @@ def vector_search_naive(openai_api_key,
 
     pipeline = [vector_search_stage, remove_embedding_stage]
 
-    results = collection.aggregate(pipeline)
-
-    #explain_query_execution = db.command(
-    #    "explain", {
-    #        "aggregate": collection.name,
-    #        "pipeline": pipeline,
-    #        "cursor": {}
-    #    },
-    #    verbosity='executionStats')
-
-    #vector_search_explain = explain_query_execution["stages"][0]["$vectorSearch"]
-    #millis_elapsed = vector_search_explain["explain"]["collectStats"]["millisElapsed"]
-    #print(f"Query execution time: {millis_elapsed} milliseconds")
-
-    return list(results)
+    return invoke_search(collection, pipeline)
 
 def vector_search_advanced(openai_api_key,
                            user_query,
@@ -252,16 +196,6 @@ def vector_search_advanced(openai_api_key,
     if query_embedding is None:
         return "Invalid query or embedding generation failed."
 
-    #vector_search_stage = {
-    #    "$vectorSearch": {
-    #        "index": vector_index,
-    #        "queryVector": query_embedding,
-    #        "path": "description_embedding",
-    #        "numCandidates": 150,
-    #        "limit": 25,
-    #    }
-    #}
-
     vector_search_stage = {
         "$vectorSearch": {
             "index": vector_index,
@@ -284,20 +218,10 @@
 
     pipeline = [vector_search_stage, remove_embedding_stage] + additional_stages
 
-    results = collection.aggregate(pipeline)
-
-    #explain_query_execution = db.command(
-    #    "explain", {
-    #        "aggregate": collection.name,
-    #        "pipeline": pipeline,
-    #        "cursor": {}
-    #    },
-    #    verbosity='executionStats')
-
-    #vector_search_explain = explain_query_execution["stages"][0]["$vectorSearch"]
-    #millis_elapsed = vector_search_explain["explain"]["collectStats"]["millisElapsed"]
-    #print(f"Query execution time: {millis_elapsed} milliseconds")
+    return invoke_search(collection, pipeline)
 
+def invoke_search(collection, pipeline):
+    results = collection.aggregate(pipeline)
     return list(results)
 
 def get_text_embedding(openai_api_key, text):
@@ -307,10 +231,10 @@ def get_text_embedding(openai_api_key, text):
     openai.api_key = openai_api_key
 
     try:
-        embedding = openai.embeddings.create(
+        return openai.embeddings.create(
             input=text,
-            model="text-embedding-3-small", dimensions=1536).data[0].embedding
-        return embedding
+            model="text-embedding-3-small", dimensions=1536
+        ).data[0].embedding
     except Exception as e:
         print(f"Error in get_embedding: {e}")
         return None
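A quick plain-Python sanity check of what the new scoring stages compute per document. The weighting follows the diff (0.9 * averageReviewScore + 0.1 * reviewCountBoost, then a descending $sort on combinedScore); the sample values and the full list of review_scores_* fields are illustrative assumptions, since only $review_scores_value and the new divisor of 7 are visible in the hunks above.

# Illustrative only: mirrors review_average_stage, weighting_stage, and sorting_stage_sort in plain Python.
# Field names other than review_scores_value and number_of_reviews are assumed, not taken from the diff.
doc = {
    "review_scores_rating": 95,
    "review_scores_accuracy": 10,
    "review_scores_cleanliness": 9,
    "review_scores_checkin": 10,
    "review_scores_communication": 10,
    "review_scores_location": 9,
    "review_scores_value": 9,
    "number_of_reviews": 120,
}

review_fields = [f for f in doc if f.startswith("review_scores_")]
average_review_score = sum(doc[f] for f in review_fields) / 7    # $divide by 7 (was 6 before this commit)
review_count_boost = doc["number_of_reviews"]                    # reviewCountBoost: $number_of_reviews

combined_score = 0.9 * average_review_score + 0.1 * review_count_boost  # weighting_stage
print(round(combined_score, 2))  # documents are then sorted by combinedScore, descending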
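For reference, a minimal sketch of how the reworked helpers might be driven end to end after this commit, using only signatures visible in the diff (rag_retrieval_naive, rag_inference, inference). The environment variable names, connection string, database and collection names, and the prompt are placeholders, not part of the repository.

import os
from pymongo import MongoClient

import custom_utils

openai_api_key = os.environ["OPENAI_API_KEY"]            # placeholder environment variable
client = MongoClient(os.environ["MONGODB_ATLAS_URI"])    # placeholder environment variable
db = client["airbnb_dataset"]                            # placeholder database name
collection = db["listings_reviews"]                      # placeholder collection name

prompt = "Apartment for two adults with one bedroom near the beach"

# Naive RAG path: vector search over the collection, then generation grounded in the hits
retrieval_result = custom_utils.rag_retrieval_naive(openai_api_key, prompt, db, collection)
answer = custom_utils.rag_inference(openai_api_key, prompt, retrieval_result)
print(answer)

# Baseline answer without retrieval, for comparison
print(custom_utils.inference(openai_api_key, prompt))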