gourisankar85 commited on
Commit
5ba6f5c
·
verified ·
1 Parent(s): 1f7a441

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +156 -181
app.py CHANGED
@@ -1,181 +1,156 @@
1
- import gradio as gr
2
- import logging
3
- import threading
4
- import time
5
- from generator.compute_metrics import get_attributes_text
6
- from generator.generate_metrics import generate_metrics, retrieve_and_generate_response
7
- from config import AppConfig, ConfigConstants
8
- from generator.initialize_llm import initialize_generation_llm, initialize_validation_llm
9
-
10
- def launch_gradio(config : AppConfig):
11
- """
12
- Launch the Gradio app with pre-initialized objects.
13
- """
14
- logger = logging.getLogger()
15
- logger.setLevel(logging.INFO)
16
-
17
- # Create a list to store logs
18
- logs = []
19
-
20
- # Custom log handler to capture logs and add them to the logs list
21
- class LogHandler(logging.Handler):
22
- def emit(self, record):
23
- log_entry = self.format(record)
24
- logs.append(log_entry)
25
-
26
- # Add custom log handler to the logger
27
- log_handler = LogHandler()
28
- log_handler.setFormatter(logging.Formatter('%(asctime)s - %(message)s'))
29
- logger.addHandler(log_handler)
30
-
31
- def log_updater():
32
- """Background function to add logs."""
33
- while True:
34
- time.sleep(2) # Update logs every 2 seconds
35
- pass # Log capture is now handled by the logging system
36
-
37
- def get_logs():
38
- """Retrieve logs for display."""
39
- return "\n".join(logs[-50:]) # Only show the last 50 logs for example
40
-
41
- # Start the logging thread
42
- threading.Thread(target=log_updater, daemon=True).start()
43
-
44
- def answer_question(query, state):
45
- try:
46
- # Generate response using the passed objects
47
- response, source_docs = retrieve_and_generate_response(config.gen_llm, config.vector_store, query)
48
-
49
- # Update state with the response and source documents
50
- state["query"] = query
51
- state["response"] = response
52
- state["source_docs"] = source_docs
53
-
54
- response_text = f"Response: {response}\n\n"
55
- return response_text, state
56
- except Exception as e:
57
- logging.error(f"Error processing query: {e}")
58
- return f"An error occurred: {e}", state
59
-
60
- def compute_metrics(state):
61
- try:
62
- logging.info(f"Computing metrics")
63
-
64
- # Retrieve response and source documents from state
65
- response = state.get("response", "")
66
- source_docs = state.get("source_docs", {})
67
- query = state.get("query", "")
68
-
69
- # Generate metrics using the passed objects
70
- attributes, metrics = generate_metrics(config.val_llm, response, source_docs, query, 1)
71
-
72
- attributes_text = get_attributes_text(attributes)
73
-
74
- metrics_text = "Metrics:\n"
75
- for key, value in metrics.items():
76
- if key != 'response':
77
- metrics_text += f"{key}: {value}\n"
78
-
79
- return attributes_text, metrics_text
80
- except Exception as e:
81
- logging.error(f"Error computing metrics: {e}")
82
- return f"An error occurred: {e}", ""
83
-
84
- def reinitialize_gen_llm(gen_llm_name):
85
- """Reinitialize the generation LLM and return updated model info."""
86
- if gen_llm_name.strip(): # Only update if input is not empty
87
- config.gen_llm = initialize_generation_llm(gen_llm_name)
88
-
89
- # Return updated model information
90
- updated_model_info = (
91
- f"Embedding Model: {ConfigConstants.EMBEDDING_MODEL_NAME}\n"
92
- f"Generation LLM: {config.gen_llm.name if hasattr(config.gen_llm, 'name') else 'Unknown'}\n"
93
- f"Validation LLM: {config.val_llm.name if hasattr(config.val_llm, 'name') else 'Unknown'}\n"
94
- )
95
- return updated_model_info
96
-
97
- def reinitialize_val_llm(val_llm_name):
98
- """Reinitialize the generation LLM and return updated model info."""
99
- if val_llm_name.strip(): # Only update if input is not empty
100
- config.val_llm = initialize_validation_llm(val_llm_name)
101
-
102
- # Return updated model information
103
- updated_model_info = (
104
- f"Embedding Model: {ConfigConstants.EMBEDDING_MODEL_NAME}\n"
105
- f"Generation LLM: {config.gen_llm.name if hasattr(config.gen_llm, 'name') else 'Unknown'}\n"
106
- f"Validation LLM: {config.val_llm.name if hasattr(config.val_llm, 'name') else 'Unknown'}\n"
107
- )
108
- return updated_model_info
109
-
110
- # Define Gradio Blocks layout
111
- with gr.Blocks() as interface:
112
- interface.title = "Real Time RAG Pipeline Q&A"
113
- gr.Markdown("### Real Time RAG Pipeline Q&A") # Heading
114
-
115
- # Textbox for new generation LLM name
116
- with gr.Row():
117
- new_gen_llm_input = gr.Textbox(label="New Generation LLM Name", placeholder="Enter LLM name to update")
118
- update_gen_llm_button = gr.Button("Update Generation LLM")
119
- new_val_llm_input = gr.Textbox(label="New Validation LLM Name", placeholder="Enter LLM name to update")
120
- update_val_llm_button = gr.Button("Update Validation LLM")
121
-
122
- # Section to display LLM names
123
- with gr.Row():
124
- model_info = f"Embedding Model: {ConfigConstants.EMBEDDING_MODEL_NAME}\n"
125
- model_info += f"Generation LLM: {config.gen_llm.name if hasattr(config.gen_llm, 'name') else 'Unknown'}\n"
126
- model_info += f"Validation LLM: {config.val_llm.name if hasattr(config.val_llm, 'name') else 'Unknown'}\n"
127
- model_info_display = gr.Textbox(value=model_info, label="Model Information", interactive=False) # Read-only textbox
128
-
129
- # State to store response and source documents
130
- state = gr.State(value={"query": "","response": "", "source_docs": {}})
131
- gr.Markdown("Ask a question and get a response with metrics calculated from the RAG pipeline.") # Description
132
- with gr.Row():
133
- query_input = gr.Textbox(label="Ask a question", placeholder="Type your query here")
134
- with gr.Row():
135
- submit_button = gr.Button("Submit", variant="primary") # Submit button
136
- clear_query_button = gr.Button("Clear") # Clear button
137
- with gr.Row():
138
- answer_output = gr.Textbox(label="Response", placeholder="Response will appear here")
139
-
140
- with gr.Row():
141
- compute_metrics_button = gr.Button("Compute metrics", variant="primary")
142
- attr_output = gr.Textbox(label="Attributes", placeholder="Attributes will appear here")
143
- metrics_output = gr.Textbox(label="Metrics", placeholder="Metrics will appear here")
144
-
145
- #with gr.Row():
146
-
147
- # Define button actions
148
- submit_button.click(
149
- fn=answer_question,
150
- inputs=[query_input, state],
151
- outputs=[answer_output, state]
152
- )
153
- clear_query_button.click(fn=lambda: "", outputs=[query_input]) # Clear query input
154
- compute_metrics_button.click(
155
- fn=compute_metrics,
156
- inputs=[state],
157
- outputs=[attr_output, metrics_output]
158
- )
159
-
160
- update_gen_llm_button.click(
161
- fn=reinitialize_gen_llm,
162
- inputs=[new_gen_llm_input],
163
- outputs=[model_info_display] # Update the displayed model info
164
- )
165
-
166
- update_val_llm_button.click(
167
- fn=reinitialize_val_llm,
168
- inputs=[new_val_llm_input],
169
- outputs=[model_info_display] # Update the displayed model info
170
- )
171
-
172
- # Section to display logs
173
- with gr.Row():
174
- start_log_button = gr.Button("Start Log Update", elem_id="start_btn") # Button to start log updates
175
- with gr.Row():
176
- log_section = gr.Textbox(label="Logs", interactive=False, visible=True, lines=10) # Log section
177
-
178
- # Set button click to trigger log updates
179
- start_log_button.click(fn=get_logs, outputs=log_section)
180
-
181
- interface.launch(share=True)
 
1
+ import gradio as gr
2
+ import logging
3
+ import threading
4
+ import time
5
+ from generator.compute_metrics import get_attributes_text
6
+ from generator.generate_metrics import generate_metrics, retrieve_and_generate_response
7
+ from config import AppConfig, ConfigConstants
8
+ from generator.initialize_llm import initialize_generation_llm, initialize_validation_llm
9
+ from generator.document_utils import get_logs, initialize_logging
10
+
11
+ def launch_gradio(config : AppConfig):
12
+ """
13
+ Launch the Gradio app with pre-initialized objects.
14
+ """
15
+ initialize_logging()
16
+
17
+ def update_logs_periodically():
18
+ while True:
19
+ time.sleep(2) # Wait for 2 seconds
20
+ yield get_logs()
21
+
22
+ def answer_question(query, state):
23
+ try:
24
+ # Generate response using the passed objects
25
+ response, source_docs = retrieve_and_generate_response(config.gen_llm, config.vector_store, query)
26
+
27
+ # Update state with the response and source documents
28
+ state["query"] = query
29
+ state["response"] = response
30
+ state["source_docs"] = source_docs
31
+
32
+ response_text = f"Response: {response}\n\n"
33
+ return response_text, state
34
+ except Exception as e:
35
+ logging.error(f"Error processing query: {e}")
36
+ return f"An error occurred: {e}", state
37
+
38
+ def compute_metrics(state):
39
+ try:
40
+ logging.info(f"Computing metrics")
41
+
42
+ # Retrieve response and source documents from state
43
+ response = state.get("response", "")
44
+ source_docs = state.get("source_docs", {})
45
+ query = state.get("query", "")
46
+
47
+ # Generate metrics using the passed objects
48
+ attributes, metrics = generate_metrics(config.val_llm, response, source_docs, query, 1)
49
+
50
+ attributes_text = get_attributes_text(attributes)
51
+
52
+ metrics_text = "Metrics:\n"
53
+ for key, value in metrics.items():
54
+ if key != 'response':
55
+ metrics_text += f"{key}: {value}\n"
56
+
57
+ return attributes_text, metrics_text
58
+ except Exception as e:
59
+ logging.error(f"Error computing metrics: {e}")
60
+ return f"An error occurred: {e}", ""
61
+
62
+ def reinitialize_llm(model_type, model_name):
63
+ """Reinitialize the specified LLM (generation or validation) and return updated model info."""
64
+ if model_name.strip(): # Only update if input is not empty
65
+ if model_type == "generation":
66
+ config.gen_llm = initialize_generation_llm(model_name)
67
+ elif model_type == "validation":
68
+ config.val_llm = initialize_validation_llm(model_name)
69
+
70
+ return get_updated_model_info()
71
+
72
+ def get_updated_model_info():
73
+ """Generate and return the updated model information string."""
74
+ return (
75
+ f"Embedding Model: {ConfigConstants.EMBEDDING_MODEL_NAME}\n"
76
+ f"Generation LLM: {config.gen_llm.name if hasattr(config.gen_llm, 'name') else 'Unknown'}\n"
77
+ f"Validation LLM: {config.val_llm.name if hasattr(config.val_llm, 'name') else 'Unknown'}\n"
78
+ )
79
+
80
+ # Wrappers for event listeners
81
+ def reinitialize_gen_llm(gen_llm_name):
82
+ return reinitialize_llm("generation", gen_llm_name)
83
+
84
+ def reinitialize_val_llm(val_llm_name):
85
+ return reinitialize_llm("validation", val_llm_name)
86
+
87
+ # Define Gradio Blocks layout
88
+ with gr.Blocks() as interface:
89
+ interface.title = "Real Time RAG Pipeline Q&A"
90
+ gr.Markdown("# Real Time RAG Pipeline Q&A") # Heading
91
+
92
+ # Textbox for new generation LLM name
93
+ with gr.Row():
94
+ new_gen_llm_input = gr.Dropdown(
95
+ label="Generation Model",
96
+ choices=ConfigConstants.GENERATION_MODELS, # Directly use the list
97
+ value=ConfigConstants.GENERATION_MODELS[0] if ConfigConstants.GENERATION_MODELS else None, # First value dynamically
98
+ interactive=True
99
+ )
100
+
101
+ new_val_llm_input = gr.Dropdown(
102
+ label="Validation Model",
103
+ choices=ConfigConstants.VALIDATION_MODELS, # Directly use the list
104
+ value=ConfigConstants.VALIDATION_MODELS[0] if ConfigConstants.VALIDATION_MODELS else None, # First value dynamically
105
+ interactive=True
106
+ )
107
+
108
+ model_info_display = gr.Textbox(
109
+ value=get_updated_model_info(), # Use the helper function
110
+ label="System Information",
111
+ interactive=False # Read-only textbox
112
+ )
113
+
114
+ # State to store response and source documents
115
+ state = gr.State(value={"query": "","response": "", "source_docs": {}})
116
+ gr.Markdown("Ask a question and get a response with metrics calculated from the RAG pipeline.") # Description
117
+ with gr.Row():
118
+ query_input = gr.Textbox(label="Ask a question", placeholder="Type your query here")
119
+ with gr.Row():
120
+ submit_button = gr.Button("Submit", variant="primary", scale = 0) # Submit button
121
+ clear_query_button = gr.Button("Clear", scale = 0) # Clear button
122
+ with gr.Row():
123
+ answer_output = gr.Textbox(label="Response", placeholder="Response will appear here")
124
+
125
+ with gr.Row():
126
+ compute_metrics_button = gr.Button("Compute metrics", variant="primary" , scale = 0)
127
+ attr_output = gr.Textbox(label="Attributes", placeholder="Attributes will appear here")
128
+ metrics_output = gr.Textbox(label="Metrics", placeholder="Metrics will appear here")
129
+
130
+ #with gr.Row():
131
+ # Attach event listeners to update model info on change
132
+ new_gen_llm_input.change(reinitialize_gen_llm, inputs=new_gen_llm_input, outputs=model_info_display)
133
+ new_val_llm_input.change(reinitialize_val_llm, inputs=new_val_llm_input, outputs=model_info_display)
134
+
135
+ # Define button actions
136
+ submit_button.click(
137
+ fn=answer_question,
138
+ inputs=[query_input, state],
139
+ outputs=[answer_output, state]
140
+ )
141
+ clear_query_button.click(fn=lambda: "", outputs=[query_input]) # Clear query input
142
+ compute_metrics_button.click(
143
+ fn=compute_metrics,
144
+ inputs=[state],
145
+ outputs=[attr_output, metrics_output]
146
+ )
147
+
148
+ # Section to display logs
149
+ with gr.Row():
150
+ log_section = gr.Textbox(label="Logs", interactive=False, visible=True, lines=10 , every=2) # Log section
151
+
152
+ # Update UI when logs_state changes
153
+ interface.queue()
154
+ interface.load(update_logs_periodically, outputs=log_section)
155
+
156
+ interface.launch()