Spaces:
Running
Running
import gradio as gr | |
import logging, os, sys, threading | |
from custom_utils import connect_to_database, rag_ingestion, rag_retrieval, rag_inference | |
# Held by ``invoke`` around its database connect / retrieval / inference
# sequence, so only one request runs that sequence at a time.
lock = threading.Lock()

# When True, ``invoke`` calls ``rag_ingestion(collection)`` and returns its
# result instead of answering the prompt.
RAG_INGESTION = False

# Choices offered by the "Retrieval-Augmented Generation" radio button.
RAG_OFF = "Off"
RAG_NAIVE = "Naive RAG"
RAG_ADVANCED = "Advanced RAG"

# Log at INFO level to stdout. basicConfig already attaches a stdout
# StreamHandler to the root logger; attaching a second one (as this module
# previously did) made every log record appear twice.
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
def _advanced_rag_stages():
    """Return the extra MongoDB aggregation stages used for "Advanced RAG".

    The list is passed to ``rag_retrieval`` as its ``additional_stages``
    argument (presumably appended after its vector-search stage inside
    custom_utils -- verify there). The stages re-rank hits by review data:

    1. ``$addFields``: ``averageReviewScore`` = mean of the six
       ``review_scores.*`` sub-scores, and ``reviewCountBoost`` =
       ``number_of_reviews``.
    2. ``$addFields``: ``combinedScore`` = 0.9 * averageReviewScore
       + 0.1 * reviewCountBoost.
    3. ``$sort``: ``combinedScore`` descending.

    A new list is built on every call, so a caller may extend it (for example
    with a ``$match`` pre-filter or a ``$project`` result filter) without
    affecting later requests.

    NOTE(review): MongoDB's ``$add`` yields null when an operand is null or
    missing, so listings lacking any sub-score get a null ``combinedScore``
    and sort after scored listings -- confirm that ranking is intended.

    Returns:
        list[dict]: aggregation pipeline stages.
    """
    review_average_stage = {
        "$addFields": {
            "averageReviewScore": {
                "$divide": [
                    {
                        "$add": [
                            "$review_scores.review_scores_accuracy",
                            "$review_scores.review_scores_cleanliness",
                            "$review_scores.review_scores_checkin",
                            "$review_scores.review_scores_communication",
                            "$review_scores.review_scores_location",
                            "$review_scores.review_scores_value",
                        ]
                    },
                    6,  # number of review sub-scores summed above
                ]
            },
            # Score boost factor based on the number of reviews.
            "reviewCountBoost": "$number_of_reviews",
        }
    }
    weighting_stage = {
        "$addFields": {
            "combinedScore": {
                "$add": [
                    {"$multiply": ["$averageReviewScore", 0.9]},  # review quality
                    {"$multiply": ["$reviewCountBoost", 0.1]},  # review volume
                ]
            }
        }
    }
    # Highest combined score first.
    sorting_stage = {"$sort": {"combinedScore": -1}}
    return [review_average_stage, weighting_stage, sorting_stage]


def invoke(openai_api_key, prompt, rag_option):
    """Gradio handler: answer *prompt*, optionally grounded by MongoDB retrieval.

    Args:
        openai_api_key: OpenAI API key, forwarded to ``rag_retrieval`` and
            ``rag_inference``.
        prompt: the user's question.
        rag_option: one of ``RAG_OFF``, ``RAG_NAIVE`` or ``RAG_ADVANCED``
            (the radio-button choices):

            - ``RAG_OFF``: no retrieval; the prompt goes to inference with an
              empty result list.
            - ``RAG_NAIVE``: retrieval with no additional pipeline stages.
            - ``RAG_ADVANCED``: retrieval re-ranked by review data (see
              ``_advanced_rag_stages``).

    Returns:
        The value of ``rag_inference`` (shown in the Markdown output), or the
        value of ``rag_ingestion`` when the module flag ``RAG_INGESTION`` is
        set.

    Raises:
        gr.Error: if the key, prompt or RAG option is empty, or if the RAG
            option is not one of the known choices.
    """
    if not openai_api_key:
        raise gr.Error("OpenAI API Key is required.")
    if not prompt:
        raise gr.Error("Prompt is required.")
    if not rag_option:
        raise gr.Error("Retrieval-Augmented Generation is required.")
    if rag_option not in (RAG_OFF, RAG_NAIVE, RAG_ADVANCED):
        raise gr.Error(
            f"Unsupported Retrieval-Augmented Generation option: {rag_option}"
        )

    with lock:
        db, collection = connect_to_database()
        if RAG_INGESTION:
            # Maintenance mode: (re)populate the collection instead of answering.
            return rag_ingestion(collection)

        if rag_option == RAG_OFF:
            # NOTE(review): assumes rag_inference treats an empty result list
            # as "no context" -- verify against custom_utils.rag_inference.
            search_results = []
        elif rag_option == RAG_NAIVE:
            search_results = rag_retrieval(openai_api_key, prompt, db, collection, [])
        else:  # RAG_ADVANCED
            search_results = rag_retrieval(
                openai_api_key, prompt, db, collection, _advanced_rag_stages()
            )
        return rag_inference(openai_api_key, prompt, search_results)
# Close any Gradio interfaces already running in this process before
# building and launching a new one.
gr.close_all()

# Example prompt pre-filled in the UI.
PROMPT = "Recommend a place that's modern, spacious, and within walking distance from restaurants."

# Single-page UI: API key, prompt and RAG mode in; Markdown completion out.
demo = gr.Interface(
    fn=invoke,
    inputs=[
        gr.Textbox(label="OpenAI API Key", type="password", lines=1),
        gr.Textbox(label="Prompt", value=PROMPT, lines=1),
        gr.Radio(
            [RAG_OFF, RAG_NAIVE, RAG_ADVANCED],
            label="Retrieval-Augmented Generation",
            value=RAG_ADVANCED,
        ),
    ],
    outputs=[gr.Markdown(label="Completion")],
    title="Context-Aware Reasoning Application",
    # The description text comes from the DESCRIPTION environment variable;
    # fall back to an empty description instead of failing at startup with a
    # KeyError when it is not set.
    description=os.environ.get("DESCRIPTION", ""),
)
demo.launch()