advanced-rag / app.py
bstraehle's picture
Update app.py
03159c2 verified
raw
history blame
3.09 kB
import gradio as gr
import pandas as pd
import logging, os, re, sys, threading
from datasets import load_dataset
from dotenv import load_dotenv, find_dotenv
from custom_utils import process_records, connect_to_database, handle_user_query
from pydantic import BaseModel
from typing import Optional
from IPython.display import display, HTML
# Module-wide lock; invoke() holds it while handling a request, so requests
# are processed one at a time.
lock = threading.Lock()

# Populate os.environ from the nearest .env file, if one is found
# (the UI below reads DESCRIPTION from the environment).
_ = load_dotenv(find_dotenv())

# Whether to (re)ingest documents: load, split, embed, and store them.
RAG_INGESTION = False

# Choices offered by the "Retrieval-Augmented Generation" radio button.
RAG_OFF = "Off"
RAG_NAIVE = "Naive RAG"
RAG_ADVANCED = "Advanced RAG"

# basicConfig already attaches a StreamHandler(sys.stdout) to the root logger.
# Adding a second stdout handler on top of it would print every record twice.
logging.basicConfig(stream = sys.stdout, level = logging.INFO)
def _ingest_listings(collection):
    """Replace the contents of *collection* with the MongoDB/airbnb_embeddings dataset.

    Streams the ``train`` split from the Hugging Face Hub and materializes it
    as a pandas DataFrame. It then converts the rows with ``process_records``
    and rewrites the collection with the result.
    """
    dataset = load_dataset("MongoDB/airbnb_embeddings", streaming=True, split="train")
    dataset_df = pd.DataFrame(dataset)
    listings = process_records(dataset_df)
    # Full refresh: drop existing documents so re-running ingestion does not
    # leave duplicates behind.
    collection.delete_many({})
    if listings:
        # insert_many rejects an empty document list, so only call it with data.
        collection.insert_many(listings)
    logging.info("Data ingestion into MongoDB completed")
    # NOTE: creating the vector search index (setup_vector_search_index) is not
    # available in the free tier, so it is not done here. Presumably the index
    # has to be created separately — confirm.


def invoke(openai_api_key, prompt, rag_option):
    """Answer *prompt* using the Airbnb listings stored in MongoDB.

    Args:
        openai_api_key: OpenAI API key, forwarded to ``handle_user_query``.
        prompt: The user's question. It is passed through unchanged.
        rag_option: One of RAG_OFF / RAG_NAIVE / RAG_ADVANCED. It is only
            checked for presence here and does not yet change how the query
            is handled.

    Returns:
        Whatever ``handle_user_query`` returns. The UI displays it in the
        "Completion" textbox.

    Raises:
        gr.Error: If the API key, prompt, or RAG option is empty.
    """
    if not openai_api_key:
        raise gr.Error("OpenAI API Key is required.")
    if not prompt:
        raise gr.Error("Prompt is required.")
    if not rag_option:
        raise gr.Error("Retrieval-Augmented Generation is required.")

    # Serialize requests: ingestion wipes and reloads the shared collection,
    # so concurrent callers must not interleave with it.
    with lock:
        db, collection = connect_to_database()

        if RAG_INGESTION:
            _ingest_listings(collection)

        # The last argument is the list of additional aggregation stages
        # (e.g. a $match pre-filter). None are applied here.
        return handle_user_query(openai_api_key, prompt, db, collection, [])
# Sample question pre-filled in the Prompt textbox.
DEFAULT_PROMPT = ("I want to stay in a place that's warm and friendly, "
                  "and not too far from restaurants, can you recommend a place? "
                  "Include a reason as to why you've chosen your selection.")

# Close any Gradio apps still running from a previous execution before
# building a new one.
gr.close_all()

demo = gr.Interface(
    fn = invoke,
    inputs = [gr.Textbox(label = "OpenAI API Key", type = "password", lines = 1),
              gr.Textbox(label = "Prompt", value = DEFAULT_PROMPT, lines = 1),
              gr.Radio([RAG_OFF, RAG_NAIVE, RAG_ADVANCED], label = "Retrieval-Augmented Generation", value = RAG_ADVANCED)],
    outputs = [gr.Textbox(label = "Completion")],
    title = "Context-Aware Reasoning Application",
    # Required: raises KeyError at startup if DESCRIPTION is not set.
    description = os.environ["DESCRIPTION"]
)

demo.launch()