yashbyname commited on
Commit
c00355e
·
verified ·
1 Parent(s): f3a1e2d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -14
app.py CHANGED
@@ -3,6 +3,7 @@ import gradio as gr
3
  import logging
4
  import json
5
  import tensorflow as tf
 
6
  import numpy as np
7
  from PIL import Image
8
  import io
@@ -18,7 +19,16 @@ external_user_id = 'plugin-1717464304'
18
  # Load the keras model
19
  def load_model():
20
  try:
21
- model = tf.keras.models.load_model('model_epoch_01.h5.keras')
 
 
 
 
 
 
 
 
 
22
  logger.info("Model loaded successfully")
23
  return model
24
  except Exception as e:
@@ -32,8 +42,13 @@ def preprocess_image(image):
32
  if isinstance(image, Image.Image):
33
  image = np.array(image)
34
 
 
 
 
 
 
 
35
  # Resize image to match model's expected input shape
36
- # Note: Adjust these dimensions to match your model's requirements
37
  target_size = (224, 224) # Change this to match your model's input size
38
  image = tf.image.resize(image, target_size)
39
 
@@ -83,9 +98,7 @@ def submit_query(session_id, query, image_analysis=None):
83
 
84
  structured_query = f"""
85
  Based on the following patient information and image analysis, provide a detailed medical analysis in JSON format:
86
-
87
  {query_with_image}
88
-
89
  Return only valid JSON with these fields:
90
  - diagnosis_details
91
  - probable_diagnoses (array)
@@ -119,6 +132,7 @@ def extract_json_from_answer(answer):
119
  return json.loads(answer)
120
  except json.JSONDecodeError:
121
  try:
 
122
  start_idx = answer.find('{')
123
  end_idx = answer.rfind('}') + 1
124
  if start_idx != -1 and end_idx != 0:
@@ -128,11 +142,31 @@ def extract_json_from_answer(answer):
128
  logger.error("Failed to parse JSON from response")
129
  raise
130
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
  # Initialize the model
132
- model = load_model()
 
 
 
 
133
 
134
  def gradio_interface(patient_info, image):
135
  try:
 
 
 
136
  # Process image if provided
137
  image_analysis = None
138
  if image is not None:
@@ -143,11 +177,7 @@ def gradio_interface(patient_info, image):
143
  prediction = model.predict(processed_image)
144
 
145
  # Format prediction results
146
- # Note: Adjust this based on your model's output format
147
- image_analysis = {
148
- "prediction": float(prediction[0][0]), # Adjust indexing based on your model's output
149
- "confidence": float(prediction[0][0]) * 100 # Convert to percentage
150
- }
151
 
152
  # Create chat session and submit query
153
  session_id = create_chat_session()
@@ -155,17 +185,17 @@ def gradio_interface(patient_info, image):
155
  json.dumps(image_analysis) if image_analysis else None)
156
 
157
  if not llm_response or 'data' not in llm_response or 'answer' not in llm_response['data']:
158
- raise ValueError("Invalid response structure")
159
 
160
  # Extract and clean JSON from the response
161
  json_data = extract_json_from_answer(llm_response['data']['answer'])
162
 
163
- # Return clean JSON string
164
- return json.dumps(json_data)
165
 
166
  except Exception as e:
167
  logger.error(f"Error in gradio_interface: {str(e)}")
168
- return json.dumps({"error": str(e)})
169
 
170
  # Gradio interface
171
  iface = gr.Interface(
 
3
  import logging
4
  import json
5
  import tensorflow as tf
6
+ import tensorflow_hub as hub
7
  import numpy as np
8
  from PIL import Image
9
  import io
 
19
  # Load the keras model
20
  def load_model():
21
  try:
22
+ # Define custom objects dictionary
23
+ custom_objects = {
24
+ 'KerasLayer': hub.KerasLayer,
25
+ # Add any other custom layers your model might use
26
+ }
27
+
28
+ # Load model with custom object scope
29
+ with tf.keras.utils.custom_object_scope(custom_objects):
30
+ model = tf.keras.models.load_model('model_epoch_01.h5.keras')
31
+
32
  logger.info("Model loaded successfully")
33
  return model
34
  except Exception as e:
 
42
  if isinstance(image, Image.Image):
43
  image = np.array(image)
44
 
45
+ # Ensure image has 3 channels (RGB)
46
+ if len(image.shape) == 2: # Grayscale image
47
+ image = np.stack((image,) * 3, axis=-1)
48
+ elif len(image.shape) == 3 and image.shape[2] == 4: # RGBA image
49
+ image = image[:, :, :3]
50
+
51
  # Resize image to match model's expected input shape
 
52
  target_size = (224, 224) # Change this to match your model's input size
53
  image = tf.image.resize(image, target_size)
54
 
 
98
 
99
  structured_query = f"""
100
  Based on the following patient information and image analysis, provide a detailed medical analysis in JSON format:
 
101
  {query_with_image}
 
102
  Return only valid JSON with these fields:
103
  - diagnosis_details
104
  - probable_diagnoses (array)
 
132
  return json.loads(answer)
133
  except json.JSONDecodeError:
134
  try:
135
+ # Find the first occurrence of '{' and last occurrence of '}'
136
  start_idx = answer.find('{')
137
  end_idx = answer.rfind('}') + 1
138
  if start_idx != -1 and end_idx != 0:
 
142
  logger.error("Failed to parse JSON from response")
143
  raise
144
 
145
+ def format_prediction(prediction):
146
+ """Format model prediction into a standardized structure"""
147
+ try:
148
+ # Adjust this based on your model's output format
149
+ confidence = float(prediction[0][0])
150
+ return {
151
+ "prediction": "abnormal" if confidence > 0.5 else "normal",
152
+ "confidence": round(confidence * 100, 2)
153
+ }
154
+ except Exception as e:
155
+ logger.error(f"Error formatting prediction: {str(e)}")
156
+ raise
157
+
158
  # Initialize the model
159
+ try:
160
+ model = load_model()
161
+ except Exception as e:
162
+ logger.error(f"Failed to initialize model: {str(e)}")
163
+ model = None
164
 
165
  def gradio_interface(patient_info, image):
166
  try:
167
+ if model is None:
168
+ raise ValueError("Model not properly initialized")
169
+
170
  # Process image if provided
171
  image_analysis = None
172
  if image is not None:
 
177
  prediction = model.predict(processed_image)
178
 
179
  # Format prediction results
180
+ image_analysis = format_prediction(prediction)
 
 
 
 
181
 
182
  # Create chat session and submit query
183
  session_id = create_chat_session()
 
185
  json.dumps(image_analysis) if image_analysis else None)
186
 
187
  if not llm_response or 'data' not in llm_response or 'answer' not in llm_response['data']:
188
+ raise ValueError("Invalid response structure from LLM")
189
 
190
  # Extract and clean JSON from the response
191
  json_data = extract_json_from_answer(llm_response['data']['answer'])
192
 
193
+ # Format output for better readability
194
+ return json.dumps(json_data, indent=2)
195
 
196
  except Exception as e:
197
  logger.error(f"Error in gradio_interface: {str(e)}")
198
+ return json.dumps({"error": str(e)}, indent=2)
199
 
200
  # Gradio interface
201
  iface = gr.Interface(