bstraehle committed on
Commit
5ef932e
·
verified ·
1 Parent(s): 800f223

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +142 -36
app.py CHANGED
@@ -1,11 +1,15 @@
1
  import gradio as gr
 
2
  import logging, os, sys, threading
3
 
 
4
  from dotenv import load_dotenv, find_dotenv
5
- from document_model import Listing
6
- from hugging_face_ import get_listings
7
- from mongodb_ import get_db_collection, create_vector_search_index
8
- from openai_ import handle_user_prompt
 
 
9
 
10
  lock = threading.Lock()
11
 
@@ -20,6 +24,108 @@ RAG_ADVANCED = "Advanced RAG"
20
  logging.basicConfig(stream = sys.stdout, level = logging.INFO)
21
  logging.getLogger().addHandler(logging.StreamHandler(stream = sys.stdout))
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  def invoke(openai_api_key, prompt, rag_option):
24
  if not openai_api_key:
25
  raise gr.Error("OpenAI API Key is required.")
@@ -37,41 +143,41 @@ def invoke(openai_api_key, prompt, rag_option):
37
  and not too far from resturants, can you recommend a place?
38
  Include a reason as to why you've chosen your selection.
39
  """
40
- listings = get_listings()
41
- db, collection = get_db_collection(listings)
42
- create_vector_search_index(collection)
43
- result = handle_user_prompt(openai_api_key, prompt, db, collection)
44
- ###
 
 
 
 
 
 
 
 
45
 
46
- #del os.environ["OPENAI_API_KEY"]
 
 
 
 
 
 
 
 
 
 
47
 
 
 
 
 
 
 
48
  """
49
- if (RAG_INGESTION):
50
- if (rag_option == RAG_LANGCHAIN):
51
- #rag = LangChainRAG()
52
- #rag.ingestion(config)
53
- elif (rag_option == RAG_LLAMAINDEX):
54
- #rag = LlamaIndexRAG()
55
- #rag.ingestion(config)
56
-
57
- try:
58
- #rag = LangChainRAG()
59
- #completion, callback = rag.rag_chain(config, prompt)
60
- #result = completion["result"]
61
- elif (rag_option == RAG_LLAMAINDEX):
62
- #rag = LlamaIndexRAG()
63
- #result, callback = rag.retrieval(config, prompt)
64
- else:
65
- #rag = LangChainRAG()
66
- #completion, callback = rag.llm_chain(config, prompt)
67
- #result = completion.generations[0][0].text
68
- except Exception as e:
69
- err_msg = e
70
-
71
- raise gr.Error(e)
72
- finally:
73
- del os.environ["OPENAI_API_KEY"]
74
- """
75
 
76
  return result
77
 
 
1
  import gradio as gr
2
+ import pandas as pd
3
  import logging, os, sys, threading
4
 
5
+ from datasets import load_dataset
6
  from dotenv import load_dotenv, find_dotenv
7
+ from utils import process_records, connect_to_database, setup_vector_search_index
8
+
9
+ from pydantic import BaseModel
10
+ from typing import Optional
11
+
12
+ from IPython.display import display, HTML
13
 
14
  lock = threading.Lock()
15
 
 
24
  logging.basicConfig(stream = sys.stdout, level = logging.INFO)
25
  logging.getLogger().addHandler(logging.StreamHandler(stream = sys.stdout))
26
 
27
def vector_search(user_query, db, collection, additional_stages=None, vector_index="vector_index_text"):
    """
    Perform a vector search in the MongoDB collection based on the user query.

    Args:
        user_query (str): The user's query string.
        db (MongoClient.database): The database object; used to run an
            'explain' command for execution statistics.
        collection (MongoCollection): The MongoDB collection to search.
        additional_stages (list | None): Additional aggregation stages to
            append after the $vectorSearch stage. Defaults to no extra stages.
        vector_index (str): Name of the Atlas vector search index to use.

    Returns:
        list: A list of matching documents, or an error string if the query
        embedding could not be generated.
    """
    # Bug fix: the original used a mutable default (additional_stages=[]),
    # which is shared across calls; build a fresh list per call instead.
    stages = list(additional_stages) if additional_stages else []

    # Generate embedding for the user query.
    # NOTE(review): custom_utils is not imported at the top of this file
    # (only `utils` is) — confirm the intended module name.
    query_embedding = custom_utils.get_embedding(user_query)

    if query_embedding is None:
        return "Invalid query or embedding generation failed."

    # Define the vector search stage.
    vector_search_stage = {
        "$vectorSearch": {
            "index": vector_index,           # index to use for the search
            "queryVector": query_embedding,  # the vector representing the query
            "path": "text_embeddings",       # document field holding the vectors
            "numCandidates": 150,            # candidate matches to consider
            "limit": 20,                     # return top 20 matches
        }
    }

    # Aggregate pipeline: the vector search stage first, then any extra stages.
    pipeline = [vector_search_stage] + stages

    # Execute the search.
    results = collection.aggregate(pipeline)

    # Ask the server how it executes the pipeline (without re-running it)
    # so we can report the elapsed time of the $vectorSearch stage.
    explain_query_execution = db.command(
        'explain',
        {
            'aggregate': collection.name,  # collection the aggregation runs on
            'pipeline': pipeline,          # the aggregation pipeline to analyze
            'cursor': {},                  # default cursor behavior
        },
        verbosity='executionStats')        # per-stage execution statistics

    vector_search_explain = explain_query_execution['stages'][0]['$vectorSearch']
    millis_elapsed = vector_search_explain['explain']['collectStats']['millisElapsed']

    print(f"Total time for the execution to complete on the database server: {millis_elapsed} milliseconds")

    return list(results)
78
+
79
class SearchResultItem(BaseModel):
    """One Airbnb listing extracted from a vector search result document.

    Optional fields default to None so source documents that are missing
    those keys still validate.
    """
    name: str
    accommodates: Optional[int] = None
    bedrooms: Optional[int] = None
    # NOTE(review): custom_utils is not imported in this file (only `utils`
    # is) — confirm the intended module for the Address model.
    address: custom_utils.Address
    # Bug fix: the original declared `space: str = None`, a non-optional
    # field with a None default, which fails pydantic validation whenever
    # the 'space' key is absent or null in a result document.
    space: Optional[str] = None
85
+
86
def handle_user_query(query, db, collection, stages=None, vector_index="vector_index_text"):
    """
    Answer a user query with RAG: vector-search the collection for context,
    then ask the chat model to answer using that context.

    Args:
        query (str): The user's natural-language question.
        db: The MongoDB database object (passed through to vector_search).
        collection: The MongoDB collection to search.
        stages (list | None): Extra aggregation stages for the search.
        vector_index (str): Name of the vector search index to use.

    Returns:
        str: The model's response text. On the no-results path a
        (message, message) tuple is returned instead — NOTE(review): this
        inconsistent return type is preserved for caller compatibility.
    """
    # Bug fix: the original used a mutable default (stages=[]), which is
    # shared across calls; pass a fresh list when none is supplied.
    get_knowledge = vector_search(query, db, collection, stages or [], vector_index)

    # Check if there are any results.
    if not get_knowledge:
        return "No results found.", "No source information available."

    # Validate/normalize the raw documents into SearchResultItem models.
    search_results_models = [
        SearchResultItem(**result)
        for result in get_knowledge
    ]

    # Convert search results into a DataFrame for better rendering.
    search_results_df = pd.DataFrame([item.dict() for item in search_results_models])

    # Generate system response using OpenAI's completion.
    completion = custom_utils.openai.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[
            {
                "role": "system",
                "content": "You are a airbnb listing recommendation system."},
            {
                "role": "user",
                "content": f"Answer this user query: {query} with the following context:\n{search_results_df}"
            }
        ]
    )

    system_response = completion.choices[0].message.content

    # Print User Question and System Response for logging/inspection.
    print(f"- User Question:\n{query}\n")
    print(f"- System Response:\n{system_response}\n")

    # Display the DataFrame as an HTML table.
    # NOTE(review): display/HTML come from IPython and only render inside a
    # notebook; in a plain Gradio app this is likely dead weight — confirm.
    display(HTML(search_results_df.to_html()))

    # Return the model's answer as a string.
    return system_response
128
+
129
  def invoke(openai_api_key, prompt, rag_option):
130
  if not openai_api_key:
131
  raise gr.Error("OpenAI API Key is required.")
 
143
  and not too far from resturants, can you recommend a place?
144
  Include a reason as to why you've chosen your selection.
145
  """
146
+ dataset = load_dataset("MongoDB/airbnb_embeddings", streaming=True, split="train")
147
+ dataset = dataset.take(100)
148
+ # Convert the dataset to a pandas dataframe
149
+ dataset_df = pd.DataFrame(dataset)
150
+ dataset_df.head(5)
151
+ print("Columns:", dataset_df.columns)
152
+
153
+ listings = process_records(dataset_df)
154
+
155
+ db, collection = connect_to_database()
156
+ collection.delete_many({})
157
+ collection.insert_many(listings)
158
+ print("Data ingestion into MongoDB completed")
159
 
160
+ setup_vector_search_index(collection=collection)
161
+
162
+ search_path = "address.country"
163
+
164
+ # Create a match stage
165
+ match_stage = {
166
+ "$match": {
167
+ search_path: re.compile(r"United States"),
168
+ "accommodates": { "$gt": 1, "$lt": 5}
169
+ }
170
+ }
171
 
172
+ additional_stages = [match_stage]
173
+
174
+ query = """
175
+ I want to stay in a place that's warm and friendly,
176
+ and not too far from resturants, can you recommend a place?
177
+ Include a reason as to why you've chosen your selection"
178
  """
179
+ result = handle_user_query(query, db, collection, additional_stages)
180
+ ###
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
181
 
182
  return result
183