Spaces:
Running
Running
import gradio as gr | |
import pandas as pd | |
import logging, os, re, sys, threading | |
from datasets import load_dataset | |
from dotenv import load_dotenv, find_dotenv | |
from custom_utils import process_records, connect_to_database, handle_user_query | |
from pydantic import BaseModel | |
from typing import Optional | |
from IPython.display import display, HTML | |
# Choices offered by the "Retrieval-Augmented Generation" radio button.
RAG_OFF = "Off"
RAG_NAIVE = "Naive RAG"
RAG_ADVANCED = "Advanced RAG"

# Ingestion switch: load, split, embed, and store documents.
RAG_INGESTION = False

# Held by invoke() for the duration of a request so that requests are
# processed one at a time.
lock = threading.Lock()

# Read environment variables from the nearest .env file, if one is found.
_ = load_dotenv(find_dotenv())
# Send INFO-and-above records to stdout. basicConfig() already attaches a
# stdout StreamHandler to the root logger, so no second handler is added here:
# an extra StreamHandler on the same stream would print every record twice.
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
def _ingest_listings(collection):
    """Replace the contents of *collection* with the Airbnb sample listings.

    Streams the "MongoDB/airbnb_embeddings" dataset, converts it with
    process_records(), then deletes every existing document in *collection*
    and inserts the converted listings.
    """
    dataset = load_dataset("MongoDB/airbnb_embeddings", streaming=True, split="train")

    # Materialize the streamed dataset as a pandas DataFrame for processing.
    dataset_df = pd.DataFrame(dataset)
    listings = process_records(dataset_df)

    collection.delete_many({})
    collection.insert_many(listings)
    logging.info("Data ingestion into MongoDB completed")

    # Not available in free tier:
    # setup_vector_search_index(collection=collection)


def invoke(openai_api_key, prompt, rag_option):
    """Validate the form inputs and answer *prompt* from the listings database.

    Args:
        openai_api_key: OpenAI API key entered by the user; forwarded to
            handle_user_query().
        prompt: The user's question, e.g. "I want to stay in a place that's
            warm and friendly, and not too far from restaurants, can you
            recommend a place?". It is passed through unchanged.
        rag_option: The selected RAG mode (RAG_OFF, RAG_NAIVE or
            RAG_ADVANCED). Currently only checked for presence; it does not
            alter the query.

    Returns:
        Whatever handle_user_query() returns for the prompt.

    Raises:
        gr.Error: If any of the three inputs is empty.
    """
    if not openai_api_key:
        raise gr.Error("OpenAI API Key is required.")
    if not prompt:
        raise gr.Error("Prompt is required.")
    if not rag_option:
        raise gr.Error("Retrieval-Augmented Generation is required.")

    # Process one request at a time.
    with lock:
        db, collection = connect_to_database()

        # One-off data load, disabled unless RAG_INGESTION is switched on.
        if RAG_INGESTION:
            _ingest_listings(collection)

        # NOTE(review): handle_user_query is the helper this file imports.
        # An earlier call also passed a fifth argument, a list of extra
        # aggregation stages (a "$match" on "address.country" and
        # "accommodates") -- confirm that argument is optional.
        return handle_user_query(openai_api_key, prompt, db, collection)
# Shut down any Gradio interfaces that are still running before creating a
# new one.
gr.close_all()

# Single-page form: API key, prompt and RAG mode in; completion text out.
demo = gr.Interface(
    fn=invoke,
    inputs=[
        gr.Textbox(label="OpenAI API Key", type="password", lines=1),
        gr.Textbox(label="Prompt", value="TODO", lines=1),
        gr.Radio(
            [RAG_OFF, RAG_NAIVE, RAG_ADVANCED],
            label="Retrieval-Augmented Generation",
            value=RAG_ADVANCED,
        ),
    ],
    outputs=[gr.Textbox(label="Completion")],
    title="Context-Aware Reasoning Application",
    # The description is optional: a missing DESCRIPTION variable yields None
    # (no description shown) instead of a KeyError that stops the app.
    description=os.environ.get("DESCRIPTION"),
)

demo.launch()