# advanced-rag / app.py
# Last change: "Update app.py" by bstraehle (commit 8c19f4a, verified), 4.88 kB.
import gradio as gr
import logging, os, sys, threading
from custom_utils import connect_to_database, rag_ingestion, rag_retrieval, rag_inference
# invoke() holds this lock around the database connection and the
# ingestion / retrieval + inference calls, so those run one request at a time.
lock = threading.Lock()

# When True, invoke() returns rag_ingestion(collection) instead of running
# retrieval + inference for the prompt.
RAG_INGESTION = True

# Choices offered by the "Retrieval-Augmented Generation" radio button.
RAG_OFF = "Off"
RAG_NAIVE = "Naive RAG"
RAG_ADVANCED = "Advanced RAG"

# basicConfig already installs a StreamHandler on sys.stdout for the root
# logger (when the root logger has no handlers yet). Adding a second stdout
# StreamHandler on top of it would emit every log record twice.
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
def _review_ranking_stages():
    """Build the aggregation stages that re-rank retrieved listings by reviews.

    The stages are passed to ``rag_retrieval`` as ``additional_stages``
    (presumably appended after the vector-search stage -- confirm in
    ``custom_utils``). They add three fields and then sort on the last one:

    * ``averageReviewScore`` -- mean of the six ``review_scores.*`` sub-scores.
    * ``reviewCountBoost`` -- the raw ``number_of_reviews`` value.
    * ``combinedScore`` -- ``0.9 * averageReviewScore + 0.1 * reviewCountBoost``.

    Other stages (e.g. a ``$match`` filter or a ``$project`` that keeps
    ``{"$meta": "vectorSearchScore"}``) can be substituted here to experiment
    with different pre-/post-retrieval processing.

    Returns:
        A new list with two ``$addFields`` stages followed by a ``$sort``
        stage (descending on ``combinedScore``).
    """
    review_average_stage = {
        "$addFields": {
            # $avg over an explicit list of expressions ignores missing and
            # non-numeric values, so a listing lacking one sub-score still gets
            # the average of the others. ($add would return null as soon as
            # any operand is missing, nulling the whole average.) If all six
            # are missing, $avg yields null.
            "averageReviewScore": {
                "$avg": [
                    "$review_scores.review_scores_accuracy",
                    "$review_scores.review_scores_cleanliness",
                    "$review_scores.review_scores_checkin",
                    "$review_scores.review_scores_communication",
                    "$review_scores.review_scores_location",
                    "$review_scores.review_scores_value",
                ]
            },
            # Score boost factor based on the number of reviews.
            "reviewCountBoost": "$number_of_reviews",
        }
    }
    weighting_stage = {
        "$addFields": {
            # Example formula combining the average review score and the
            # review-count boost.
            "combinedScore": {
                "$add": [
                    {"$multiply": ["$averageReviewScore", 0.9]},
                    {"$multiply": ["$reviewCountBoost", 0.1]},
                ]
            }
        }
    }
    # Highest combinedScore first.
    sorting_stage = {"$sort": {"combinedScore": -1}}
    return [review_average_stage, weighting_stage, sorting_stage]


def invoke(openai_api_key, prompt, rag_option):
    """Gradio callback: validate the inputs, then ingest data or answer the prompt.

    Args:
        openai_api_key: OpenAI API key, forwarded to ``rag_retrieval`` and
            ``rag_inference``.
        prompt: The user's question.
        rag_option: Value of the RAG radio button (``RAG_OFF``,
            ``RAG_NAIVE`` or ``RAG_ADVANCED``).

    Returns:
        ``rag_ingestion(collection)`` when ``RAG_INGESTION`` is set; otherwise
        the result of ``rag_inference``, which the UI renders in its Markdown
        output component.

    Raises:
        gr.Error: If the API key, prompt or RAG option is empty.
    """
    if not openai_api_key:
        raise gr.Error("OpenAI API Key is required.")
    if not prompt:
        raise gr.Error("Prompt is required.")
    if not rag_option:
        raise gr.Error("Retrieval-Augmented Generation is required.")

    # NOTE(review): rag_option is validated above but not used otherwise, so
    # the Off / Naive / Advanced selection does not change the behaviour.
    with lock:
        db, collection = connect_to_database()

        if RAG_INGESTION:
            return rag_ingestion(collection)

        additional_stages = _review_ranking_stages()
        search_results = rag_retrieval(openai_api_key, prompt, db, collection, additional_stages)
        return rag_inference(openai_api_key, prompt, search_results)
# Default question pre-filled in the "Prompt" textbox.
PROMPT = "Recommend a place that's modern, spacious, and within walking distance from restaurants."

# Close any Gradio apps still running in this process before building a new one.
gr.close_all()

# Three inputs (API key, prompt, RAG mode) feed invoke(); its result is shown
# as Markdown. The description comes from the DESCRIPTION environment variable
# and raises KeyError at import time if that variable is unset.
demo = gr.Interface(
    fn=invoke,
    inputs=[
        gr.Textbox(label="OpenAI API Key", type="password", lines=1),
        gr.Textbox(label="Prompt", value=PROMPT, lines=1),
        gr.Radio(
            choices=[RAG_OFF, RAG_NAIVE, RAG_ADVANCED],
            label="Retrieval-Augmented Generation",
            value=RAG_ADVANCED,
        ),
    ],
    outputs=[gr.Markdown(label="Completion")],
    title="Context-Aware Reasoning Application",
    description=os.environ["DESCRIPTION"],
)

demo.launch()