bstraehle committed on
Commit
d5c11ea
·
verified ·
1 Parent(s): 6a5cc80

Update custom_utils.py

Browse files
Files changed (1) hide show
  1. custom_utils.py +101 -18
custom_utils.py CHANGED
@@ -1,8 +1,6 @@
1
  import openai, os, time
2
- #import pandas as pd
3
 
4
  from datasets import load_dataset
5
- #from pydantic import ValidationError
6
  from pymongo.collection import Collection
7
  from pymongo.errors import OperationFailure
8
  from pymongo.mongo_client import MongoClient
@@ -25,13 +23,35 @@ def rag_ingestion(collection):
25
  collection.insert_many(dataset)
26
  return "Manually create a vector search index (in free tier, this feature is not available via SDK)"
27
 
28
- def rag_retrieval(openai_api_key,
29
- prompt,
30
- accomodates,
31
- bedrooms,
32
- db,
33
- collection,
34
- vector_index="vector_index"):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  ###
36
  ### Pre-retrieval processing: index filter
37
  ### Post-retrieval processing: result filter
@@ -115,7 +135,7 @@ def rag_retrieval(openai_api_key,
115
  ###
116
  ###
117
 
118
- get_knowledge = vector_search(
119
  openai_api_key,
120
  prompt,
121
  accomodates,
@@ -156,14 +176,77 @@ def rag_inference(openai_api_key,
156
 
157
  return completion.choices[0].message.content
158
 
159
- def vector_search(openai_api_key,
160
- user_query,
161
- accommodates,
162
- bedrooms,
163
- db,
164
- collection,
165
- additional_stages=[],
166
- vector_index="vector_index"):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
167
  query_embedding = get_text_embedding(openai_api_key, user_query)
168
 
169
  if query_embedding is None:
 
1
  import openai, os, time
 
2
 
3
  from datasets import load_dataset
 
4
  from pymongo.collection import Collection
5
  from pymongo.errors import OperationFailure
6
  from pymongo.mongo_client import MongoClient
 
23
  collection.insert_many(dataset)
24
  return "Manually create a vector search index (in free tier, this feature is not available via SDK)"
25
 
26
def rag_retrieval_naive(openai_api_key,
                        prompt,
                        db,
                        collection,
                        vector_index="vector_index"):
    """Run the unfiltered vector search for ``prompt`` and return its result.

    All arguments are forwarded unchanged to ``vector_search_naive``.

    Returns:
        The value produced by ``vector_search_naive`` when it is truthy
        (it is also echoed to stdout between ``###`` markers). When the
        search yields a falsy value (e.g. an empty list), a 2-tuple of
        placeholder strings is returned instead.
    """
    hits = vector_search_naive(openai_api_key, prompt, db, collection, vector_index)

    if hits:
        # Debug trace of the retrieved knowledge, framed by marker lines.
        for line in ("###", hits, "###"):
            print(line)
        return hits

    return "No results found.", "No source information available."
47
+
48
+ def rag_retrieval_advanced(openai_api_key,
49
+ prompt,
50
+ accomodates,
51
+ bedrooms,
52
+ db,
53
+ collection,
54
+ vector_index="vector_index"):
55
  ###
56
  ### Pre-retrieval processing: index filter
57
  ### Post-retrieval processing: result filter
 
135
  ###
136
  ###
137
 
138
+ get_knowledge = vector_search_advanced(
139
  openai_api_key,
140
  prompt,
141
  accomodates,
 
176
 
177
  return completion.choices[0].message.content
178
 
179
def inference(openai_api_key,
              prompt):
    """Answer ``prompt`` with the chat model alone, without retrieved context.

    Sets the module-level ``openai.api_key`` to ``openai_api_key``, sends a
    system message plus the wrapped user question to the ``gpt-4o`` chat
    completions endpoint, and returns the text of the first choice.
    """
    openai.api_key = openai_api_key

    system_message = {
        "role": "system",
        "content": "You are an AirBnB listing recommendation system.",
    }
    user_message = {
        "role": "user",
        "content": f"Answer this user question: {prompt}",
    }

    response = openai.chat.completions.create(
        model="gpt-4o",
        messages=[system_message, user_message],
    )

    first_choice = response.choices[0]
    return first_choice.message.content
199
+
200
def vector_search_naive(openai_api_key,
                        user_query,
                        db,
                        collection,
                        vector_index="vector_index"):
    """Embed ``user_query`` and run a plain ``$vectorSearch`` aggregation.

    The query is embedded via ``get_text_embedding``; the resulting vector is
    matched against the ``description_embedding`` field using ``vector_index``
    (150 candidates, top 25 kept). The embedding field is stripped from each
    returned document.

    ``db`` is accepted but not used by the executed code path.

    Returns:
        A list of the aggregation results, or an error message string when
        ``get_text_embedding`` returns ``None``.
    """
    embedding = get_text_embedding(openai_api_key, user_query)
    if embedding is None:
        return "Invalid query or embedding generation failed."

    search_stage = {
        "$vectorSearch": {
            "index": vector_index,
            "queryVector": embedding,
            "path": "description_embedding",
            "numCandidates": 150,
            "limit": 25,
        }
    }
    # Drop the raw embedding vector from every returned document.
    strip_embedding_stage = {"$unset": "description_embedding"}

    cursor = collection.aggregate([search_stage, strip_embedding_stage])
    return list(cursor)
241
+
242
+ def vector_search_advanced(openai_api_key,
243
+ user_query,
244
+ accommodates,
245
+ bedrooms,
246
+ db,
247
+ collection,
248
+ additional_stages=[],
249
+ vector_index="vector_index"):
250
  query_embedding = get_text_embedding(openai_api_key, user_query)
251
 
252
  if query_embedding is None: