# app.py -- uploaded by gourisankar85 ("Upload app.py", revision 8e4faf7, verified)
import gradio as gr
import logging
import time
from generator.compute_metrics import get_attributes_text
from generator.generate_metrics import generate_metrics, retrieve_and_generate_response
from config import AppConfig, ConfigConstants
from generator.initialize_llm import initialize_generation_llm, initialize_validation_llm
from generator.document_utils import get_logs, initialize_logging
from retriever.load_selected_datasets import load_selected_datasets
def launch_gradio(config: AppConfig) -> None:
    """
    Launch the Gradio app with pre-initialized objects.

    Initializes logging, asks ``config`` to detect the datasets that are
    already loaded, builds the Blocks UI, wires the event handlers and then
    starts the app with ``interface.launch()``.

    Args:
        config: Application configuration shared by every handler. The
            handlers read ``config.vector_store``, ``config.gen_llm``,
            ``config.val_llm`` and ``config.loaded_datasets``; the model
            dropdowns reassign ``config.gen_llm`` / ``config.val_llm``.
    """
    # NOTE(review): the nesting below (closures and layout containers) was
    # reconstructed from a source whose indentation had been flattened --
    # confirm the container hierarchy against the intended page layout.
    initialize_logging()

    # **🔹 Always get the latest loaded datasets**
    config.detect_loaded_datasets()

    def update_logs_periodically():
        """Yield the result of ``get_logs()`` every two seconds, indefinitely."""
        while True:
            time.sleep(2)  # Wait for 2 seconds
            yield get_logs()

    def answer_question(query, state):
        """
        Generate a response for ``query`` and remember it in ``state``.

        Returns:
            A ``(response_text, state)`` pair. ``response_text`` is a
            user-facing message when no vector store is loaded or when an
            exception is raised; ``state`` is returned in every case.
        """
        try:
            # Ensure vector store is updated before use
            if config.vector_store is None:
                return "Please load a dataset first.", state

            # Generate response using the passed objects
            response, source_docs = retrieve_and_generate_response(
                config.gen_llm, config.vector_store, query
            )

            # Update state with the response and source documents so that
            # compute_metrics can evaluate them later.
            state["query"] = query
            state["response"] = response
            state["source_docs"] = source_docs

            # Same fallback as get_updated_model_info: a model object without
            # a ``name`` must not turn a successful answer into an error.
            model_name = getattr(config.gen_llm, "name", "Unknown")
            response_text = f"Response from Model ({model_name}) : {response}\n\n"
            return response_text, state
        except Exception as e:
            # UI boundary: report the failure in the textbox instead of raising.
            logging.error("Error processing query: %s", e)
            return f"An error occurred: {e}", state

    def compute_metrics(state):
        """
        Compute attributes and metrics for the last response stored in ``state``.

        Returns:
            A ``(attributes_text, metrics_text)`` pair. On failure the first
            element carries the error message and the second is empty.
        """
        try:
            logging.info("Computing metrics")

            # Retrieve response and source documents from state
            response = state.get("response", "")
            source_docs = state.get("source_docs", {})
            query = state.get("query", "")

            # Generate metrics using the passed objects.
            # NOTE(review): the meaning of the trailing ``1`` is defined by
            # generate_metrics, which is not visible here -- confirm.
            attributes, metrics = generate_metrics(
                config.val_llm, response, source_docs, query, 1
            )
            attributes_text = get_attributes_text(attributes)

            # One "key: value" line per metric; the 'response' entry is skipped.
            metrics_text = "".join(
                f"{key}: {value}\n"
                for key, value in metrics.items()
                if key != 'response'
            )
            return attributes_text, metrics_text
        except Exception as e:
            # UI boundary: report the failure in the textbox instead of raising.
            logging.error("Error computing metrics: %s", e)
            return f"An error occurred: {e}", ""

    def get_updated_model_info():
        """Generate and return the updated model information string."""
        loaded_datasets_str = (
            ", ".join(config.loaded_datasets) if config.loaded_datasets else "None"
        )
        gen_name = getattr(config.gen_llm, "name", "Unknown")
        val_name = getattr(config.val_llm, "name", "Unknown")
        return (
            f"Embedding Model: {ConfigConstants.EMBEDDING_MODEL_NAME}\n"
            f"Generation LLM: {gen_name}\n"
            f"Re-ranking LLM: {ConfigConstants.RE_RANKER_MODEL_NAME}\n"
            f"Validation LLM: {val_name}\n"
            f"Loaded Datasets: {loaded_datasets_str}\n"
        )

    def reinitialize_llm(model_type, model_name):
        """Reinitialize the specified LLM (generation or validation) and return updated model info."""
        # Only update if input is not empty. The truthiness check also covers
        # ``None``, which a dropdown can emit when it has no selection.
        if model_name and model_name.strip():
            if model_type == "generation":
                config.gen_llm = initialize_generation_llm(model_name)
            elif model_type == "validation":
                config.val_llm = initialize_validation_llm(model_name)
        return get_updated_model_info()

    # Wrappers for event listeners
    def reinitialize_gen_llm(gen_llm_name):
        """Dropdown handler: reinitialize the generation LLM."""
        return reinitialize_llm("generation", gen_llm_name)

    def reinitialize_val_llm(val_llm_name):
        """Dropdown handler: reinitialize the validation LLM."""
        return reinitialize_llm("validation", val_llm_name)

    # Function to update query input when a question is selected from the dropdown
    def update_query_input(selected_question):
        """Return the selected question unchanged so it fills the query box."""
        return selected_question

    def load_datasets_and_refresh(datasets):
        """
        Load the selected datasets, then return the model information string.

        This handler feeds a single output component, so it must return a
        single value: returning a tuple here would put the tuple's string
        form into the textbox.
        """
        load_selected_datasets(datasets, config)
        return get_updated_model_info()

    # Same text as before, spelled as adjacent literals so that the code
    # indentation does not leak into the Markdown source.
    intro_markdown = (
        "\n"
        "# Real Time RAG Pipeline Q&A\n"
        "The **Retrieval-Augmented Generation (RAG) Pipeline** combines retrieval-based and generative AI models to provide accurate and context-aware answers to your questions.\n"
        "It retrieves relevant documents from a dataset (e.g., COVIDQA, TechQA, FinQA) and uses a generative model to synthesize a response.\n"
        "Metrics are computed to evaluate the quality of the response and the retrieval process.\n"
    )

    # Define Gradio Blocks layout
    with gr.Blocks() as interface:
        interface.title = "Real Time RAG Pipeline Q&A"
        gr.Markdown(intro_markdown)

        # Model Configuration
        with gr.Accordion("System Information", open=False):
            with gr.Accordion("DataSet", open=False):
                with gr.Row():
                    dataset_selector = gr.CheckboxGroup(
                        ConfigConstants.DATA_SET_NAMES,
                        label="Select Datasets to Load",
                    )
                    load_button = gr.Button("Load", scale=0)

            with gr.Row():
                # Column for Generation Model Dropdown
                with gr.Column(scale=1):
                    new_gen_llm_input = gr.Dropdown(
                        label="Generation Model",
                        choices=ConfigConstants.GENERATION_MODELS,
                        value=ConfigConstants.GENERATION_MODELS[0] if ConfigConstants.GENERATION_MODELS else None,
                        interactive=True,
                        info="Select the generative model for response generation."
                    )

                # Column for Validation Model Dropdown
                with gr.Column(scale=1):
                    new_val_llm_input = gr.Dropdown(
                        label="Validation Model",
                        choices=ConfigConstants.VALIDATION_MODELS,
                        value=ConfigConstants.VALIDATION_MODELS[0] if ConfigConstants.VALIDATION_MODELS else None,
                        interactive=True,
                        info="Select the model for validating the response quality."
                    )

                # Column for Model Information
                with gr.Column(scale=2):
                    model_info_display = gr.Textbox(
                        value=get_updated_model_info(),  # Use the helper function
                        label="Model Configuration",
                        interactive=False,  # Read-only textbox
                        lines=5
                    )

        # Query Section
        gr.Markdown("Ask a question and get a response with metrics calculated from the RAG pipeline.")

        all_questions = [
            "Does the ignition button have multiple modes?",
            "Why does the other instance of my multi-instance qmgr seem to hang after a failover? Queue manager will not start after failover.",
            "Is one party required to deposit its source code into escrow with a third party, which can be released to the counterparty upon the occurrence of certain events (bankruptcy, insolvency, etc.)?",
            "Explain the concept of blockchain.",
            "What is the capital of France?",
            "Do Surface Porosity and Pore Size Influence Mechanical Properties and Cellular Response to PEEK??",
            "How does a vaccine work?",
            "Tell me the step-by-step instruction for front-door installation.",
            "What are the risk factors for heart disease?",
            "What is the % change in total property and equipment from 2018 to 2019?",
            # Add more questions as needed
        ]

        # Subset of questions to display as examples
        example_questions = [
            "When was the first case of COVID-19 identified?",
            "What are the ages of the patients in this study?",
            "Why cant I load and AEL when using IE 11 JRE 8 Application Blocked by Java Security",
            "Explain the concept of blockchain.",
            "What is the capital of France?",
            "What was the change in Current deferred income?"
        ]

        with gr.Row():
            with gr.Column():
                with gr.Row():
                    query_input = gr.Textbox(
                        label="Ask a question ",
                        placeholder="Type your query here or select from examples/dropdown",
                        lines=2
                    )
                with gr.Row():
                    submit_button = gr.Button("Submit", variant="primary", scale=0)
                    clear_query_button = gr.Button("Clear", scale=0)
            with gr.Column():
                gr.Examples(
                    examples=example_questions,  # Make sure the variable name matches
                    inputs=query_input,
                    label="Try these examples:"
                )
                question_dropdown = gr.Dropdown(
                    label="",
                    choices=all_questions,
                    interactive=True,
                    info="Choose a question from the dropdown to populate the query box."
                )

        # Attach event listener to dropdown
        question_dropdown.change(
            fn=update_query_input,
            inputs=question_dropdown,
            outputs=query_input
        )

        # Response and Metrics
        with gr.Row():
            answer_output = gr.Textbox(label="Response", placeholder="Response will appear here", lines=2)

        with gr.Row():
            compute_metrics_button = gr.Button("Compute metrics", variant="primary", scale=0)
            attr_output = gr.Textbox(label="Attributes", placeholder="Attributes will appear here")
            metrics_output = gr.Textbox(label="Metrics", placeholder="Metrics will appear here")

        # State to store response and source documents
        state = gr.State(value={"query": "", "response": "", "source_docs": {}})

        # Pass config to update vector store
        load_button.click(
            fn=load_datasets_and_refresh,
            inputs=dataset_selector,
            outputs=model_info_display
        )

        # Attach event listeners to update model info on change
        new_gen_llm_input.change(reinitialize_gen_llm, inputs=new_gen_llm_input, outputs=model_info_display)
        new_val_llm_input.change(reinitialize_val_llm, inputs=new_val_llm_input, outputs=model_info_display)

        # Define button actions
        submit_button.click(
            fn=answer_question,
            inputs=[query_input, state],
            outputs=[answer_output, state]
        )
        clear_query_button.click(fn=lambda: "", outputs=[query_input])  # Clear query input
        compute_metrics_button.click(
            fn=compute_metrics,
            inputs=[state],
            outputs=[attr_output, metrics_output]
        )

        # Section to display logs
        with gr.Accordion("View Live Logs", open=False):
            with gr.Row():
                log_section = gr.Textbox(label="Logs", interactive=False, visible=True, lines=10, every=2)  # Log section

        # Event listeners must be registered while the Blocks context is
        # still open, so the load hooks stay inside the ``with`` block.
        interface.queue()
        interface.load(update_logs_periodically, outputs=log_section)
        interface.load(get_updated_model_info, outputs=model_info_display)

    interface.launch()