File size: 11,818 Bytes
8e4faf7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5ba6f5c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
import gradio as gr
import logging
import time
from generator.compute_metrics import get_attributes_text
from generator.generate_metrics import generate_metrics, retrieve_and_generate_response
from config import AppConfig, ConfigConstants
from generator.initialize_llm import initialize_generation_llm, initialize_validation_llm
from generator.document_utils import get_logs, initialize_logging
from retriever.load_selected_datasets import load_selected_datasets 

def launch_gradio(config: AppConfig):
    """
    Build and launch the Gradio UI for the RAG pipeline.

    Every callback below is a closure over ``config``: it reads
    ``config.gen_llm``, ``config.val_llm``, ``config.vector_store`` and
    ``config.loaded_datasets``, and the model dropdowns overwrite
    ``config.gen_llm`` / ``config.val_llm`` in place. As a result, changes
    made through the UI are seen by all later requests.

    Args:
        config: Application configuration holding the pre-initialized
            generation/validation LLMs and (once a dataset is loaded) the
            vector store.
    """
    initialize_logging()

    # Refresh the record of already-loaded datasets so the model-info panel
    # is accurate on first render (presumably populates
    # config.loaded_datasets, which get_updated_model_info reads).
    config.detect_loaded_datasets()

    def update_logs_periodically():
        """Yield the current log text immediately, then every 2 seconds, forever."""
        while True:
            # Yield first so the log panel is populated without an initial delay.
            yield get_logs()
            time.sleep(2)

    def answer_question(query, state):
        """Retrieve context for ``query`` and generate an answer.

        Args:
            query: Text from the query textbox.
            state: Per-session dict. On success, its ``query``, ``response``
                and ``source_docs`` keys are overwritten so that
                ``compute_metrics`` can reuse them.

        Returns:
            ``(text, state)``: the text for the response box and the
            (possibly updated) state.
        """
        # Nothing to retrieve from until a dataset has been loaded.
        if config.vector_store is None:
            return "Please load a dataset first.", state
        # Avoid a pointless retrieval + LLM call on blank input.
        if not query or not query.strip():
            return "Please enter a question.", state

        try:
            response, source_docs = retrieve_and_generate_response(
                config.gen_llm, config.vector_store, query
            )

            # Remember this exchange for the metrics step.
            state["query"] = query
            state["response"] = response
            state["source_docs"] = source_docs

            response_text = f"Response from Model ({config.gen_llm.name}) : {response}\n\n"
            return response_text, state
        except Exception as e:
            logging.exception("Error processing query: %s", e)
            return f"An error occurred: {e}", state

    def compute_metrics(state):
        """Evaluate the last stored response with the validation LLM.

        Args:
            state: Per-session dict written by ``answer_question``.

        Returns:
            ``(attributes_text, metrics_text)`` for the two output boxes.
            ``metrics_text`` lists every metric except the ``response`` key,
            one ``key: value`` per line.
        """
        response = state.get("response", "")
        # Don't spend a validation-LLM call when nothing has been answered yet.
        if not response:
            return "Please submit a question before computing metrics.", ""

        try:
            logging.info("Computing metrics")

            source_docs = state.get("source_docs", {})
            query = state.get("query", "")

            attributes, metrics = generate_metrics(config.val_llm, response, source_docs, query, 1)
            attributes_text = get_attributes_text(attributes)

            metrics_text = "".join(
                f"{key}: {value}\n"
                for key, value in metrics.items()
                if key != 'response'
            )
            return attributes_text, metrics_text
        except Exception as e:
            logging.exception("Error computing metrics: %s", e)
            return f"An error occurred: {e}", ""

    def get_updated_model_info():
        """Return the multi-line model/dataset summary shown in the config panel."""
        loaded_datasets_str = ", ".join(config.loaded_datasets) if config.loaded_datasets else "None"
        gen_name = config.gen_llm.name if hasattr(config.gen_llm, 'name') else 'Unknown'
        val_name = config.val_llm.name if hasattr(config.val_llm, 'name') else 'Unknown'
        return (
            f"Embedding Model: {ConfigConstants.EMBEDDING_MODEL_NAME}\n"
            f"Generation LLM: {gen_name}\n"
            f"Re-ranking LLM: {ConfigConstants.RE_RANKER_MODEL_NAME}\n"
            f"Validation LLM: {val_name}\n"
            f"Loaded Datasets: {loaded_datasets_str}\n"
        )

    def reinitialize_llm(model_type, model_name):
        """Replace the generation or validation LLM, then return refreshed model info.

        Args:
            model_type: ``"generation"`` or ``"validation"``; any other value
                leaves both models unchanged.
            model_name: Model identifier from the dropdown. ``None`` or blank
                values are ignored, so a cleared dropdown is a no-op.
        """
        # Guard against None: a cleared Dropdown passes None, which has no .strip().
        if model_name and model_name.strip():
            if model_type == "generation":
                config.gen_llm = initialize_generation_llm(model_name)
            elif model_type == "validation":
                config.val_llm = initialize_validation_llm(model_name)

        return get_updated_model_info()

    # Wrappers for event listeners
    def reinitialize_gen_llm(gen_llm_name):
        return reinitialize_llm("generation", gen_llm_name)

    def reinitialize_val_llm(val_llm_name):
        return reinitialize_llm("validation", val_llm_name)

    def load_datasets_and_refresh(datasets):
        """Load the selected datasets into ``config`` and return refreshed model info.

        The listener this backs has a single output component, so it must
        return exactly one value (the info string), not a tuple.
        """
        load_selected_datasets(datasets, config)
        return get_updated_model_info()

    def update_query_input(selected_question):
        """Copy the dropdown's selected question into the query box."""
        return selected_question

    # Define Gradio Blocks layout
    with gr.Blocks(title="Real Time RAG Pipeline Q&A") as interface:
        gr.Markdown("""

            # Real Time RAG Pipeline Q&A

            The **Retrieval-Augmented Generation (RAG) Pipeline** combines retrieval-based and generative AI models to provide accurate and context-aware answers to your questions. 

            It retrieves relevant documents from a dataset (e.g., COVIDQA, TechQA, FinQA) and uses a generative model to synthesize a response. 

            Metrics are computed to evaluate the quality of the response and the retrieval process.

            """)
        # Model Configuration
        with gr.Accordion("System Information", open=False):
            with gr.Accordion("DataSet", open=False):
                with gr.Row():
                    dataset_selector = gr.CheckboxGroup(ConfigConstants.DATA_SET_NAMES, label="Select Datasets to Load")
                    load_button = gr.Button("Load", scale=0)

            with gr.Row():
                # Column for Generation Model Dropdown
                with gr.Column(scale=1):
                    new_gen_llm_input = gr.Dropdown(
                        label="Generation Model",
                        choices=ConfigConstants.GENERATION_MODELS,
                        value=ConfigConstants.GENERATION_MODELS[0] if ConfigConstants.GENERATION_MODELS else None,
                        interactive=True,
                        info="Select the generative model for response generation."
                    )

                # Column for Validation Model Dropdown
                with gr.Column(scale=1):
                    new_val_llm_input = gr.Dropdown(
                        label="Validation Model",
                        choices=ConfigConstants.VALIDATION_MODELS,
                        value=ConfigConstants.VALIDATION_MODELS[0] if ConfigConstants.VALIDATION_MODELS else None,
                        interactive=True,
                        info="Select the model for validating the response quality."
                    )

                # Column for Model Information
                with gr.Column(scale=2):
                    model_info_display = gr.Textbox(
                        value=get_updated_model_info(),
                        label="Model Configuration",
                        interactive=False,  # Read-only textbox
                        lines=5
                    )

        # Query Section
        gr.Markdown("Ask a question and get a response with metrics calculated from the RAG pipeline.")
        all_questions = [
            "Does the ignition button have multiple modes?",
            "Why does the other instance of my multi-instance qmgr seem to hang after a failover? Queue manager will not start after failover.",
            "Is one party required to deposit its source code into escrow with a third party, which can be released to the counterparty upon the occurrence of certain events (bankruptcy,  insolvency, etc.)?",
            "Explain the concept of blockchain.",
            "What is the capital of France?",
            "Do Surface Porosity and Pore Size Influence Mechanical Properties and Cellular Response to PEEK??",
            "How does a vaccine work?",
            "Tell me the step-by-step instruction for front-door installation.",
            "What are the risk factors for heart disease?",
            "What is the % change in total property and equipment from 2018 to 2019?",
            # Add more questions as needed
        ]

        # Subset of questions to display as clickable examples
        example_questions = [
            "When was the first case of COVID-19 identified?",
            "What are the ages of the patients in this study?",
            "Why cant I load and AEL when using IE 11 JRE 8 Application Blocked by Java Security",
            "Explain the concept of blockchain.",
            "What is the capital of France?",
            "What was the change in Current deferred income?"
        ]
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    query_input = gr.Textbox(
                        label="Ask a question ",
                        placeholder="Type your query here or select from examples/dropdown",
                        lines=2
                    )
                with gr.Row():
                    submit_button = gr.Button("Submit", variant="primary", scale=0)
                    clear_query_button = gr.Button("Clear", scale=0)
            with gr.Column():
                gr.Examples(
                    examples=example_questions,
                    inputs=query_input,
                    label="Try these examples:"
                )
                question_dropdown = gr.Dropdown(
                    label="",
                    choices=all_questions,
                    interactive=True,
                    info="Choose a question from the dropdown to populate the query box."
                )

        # Selecting a question in the dropdown copies it into the query box.
        question_dropdown.change(
            fn=update_query_input,
            inputs=question_dropdown,
            outputs=query_input
        )

        # Response and Metrics
        with gr.Row():
            answer_output = gr.Textbox(label="Response", placeholder="Response will appear here", lines=2)

        with gr.Row():
            compute_metrics_button = gr.Button("Compute metrics", variant="primary", scale=0)
            attr_output = gr.Textbox(label="Attributes", placeholder="Attributes will appear here")
            metrics_output = gr.Textbox(label="Metrics", placeholder="Metrics will appear here")

        # Per-session store for the last query, its response and source documents.
        state = gr.State(value={"query": "", "response": "", "source_docs": {}})

        # Loading datasets mutates config; the handler returns a single
        # string to match the single output component.
        load_button.click(load_datasets_and_refresh, inputs=dataset_selector, outputs=model_info_display)
        # Swapping a model rebuilds it and refreshes the info panel.
        new_gen_llm_input.change(reinitialize_gen_llm, inputs=new_gen_llm_input, outputs=model_info_display)
        new_val_llm_input.change(reinitialize_val_llm, inputs=new_val_llm_input, outputs=model_info_display)

        # Define button actions
        submit_button.click(
            fn=answer_question,
            inputs=[query_input, state],
            outputs=[answer_output, state]
        )
        clear_query_button.click(fn=lambda: "", outputs=[query_input])  # Clear query input
        compute_metrics_button.click(
            fn=compute_metrics,
            inputs=[state],
            outputs=[attr_output, metrics_output]
        )

        # Section to display logs
        with gr.Accordion("View Live Logs", open=False):
            with gr.Row():
                log_section = gr.Textbox(label="Logs", interactive=False, visible=True, lines=10, every=2)

        # Enable the queue: update_logs_periodically is a generator (streaming) callback.
        interface.queue()
        interface.load(update_logs_periodically, outputs=log_section)
        # Re-read model/dataset info on every page load so it reflects current config.
        interface.load(get_updated_model_info, outputs=model_info_display)

    interface.launch()