yashbyname commited on
Commit
aa3dc7a
·
verified ·
1 Parent(s): aff9b10

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +82 -120
app.py CHANGED
@@ -7,181 +7,131 @@ import tensorflow_hub as hub
7
  import numpy as np
8
  from PIL import Image
9
  import io
 
10
 
11
- # Set up logging
12
- logging.basicConfig(level=logging.INFO)
 
 
 
13
  logger = logging.getLogger(__name__)
14
 
15
  # API key and user ID for on-demand
16
  api_key = 'KGSjxB1uptfSk8I8A7ciCuNT9Xa3qWC3'
17
  external_user_id = 'plugin-1717464304'
18
 
19
- # Load the keras model
20
  def load_model():
21
  try:
22
- # Define custom objects dictionary with batch normalization handling
 
 
 
 
 
 
 
 
23
  custom_objects = {
24
  'KerasLayer': hub.KerasLayer,
25
- 'BatchNormalization': tf.keras.layers.BatchNormalization
 
26
  }
27
 
28
- # Load model with custom object scope and proper batch norm behavior
29
- with tf.keras.utils.custom_object_scope(custom_objects):
30
- model = tf.keras.models.load_model(
31
- 'model_epoch_01.h5.keras',
32
- custom_objects=custom_objects,
33
- compile=False # Don't compile the model on load
34
- )
 
 
 
 
 
 
35
 
 
 
36
  logger.info("Model loaded successfully")
37
  return model
 
38
  except Exception as e:
39
  logger.error(f"Error loading model: {str(e)}")
 
40
  raise
41
 
42
- # Preprocess image for model
 
 
 
 
 
 
 
 
43
  def preprocess_image(image):
44
  try:
45
  # Convert to numpy array if needed
46
  if isinstance(image, Image.Image):
47
  image = np.array(image)
48
 
 
 
 
49
  # Ensure image has 3 channels (RGB)
50
  if len(image.shape) == 2: # Grayscale image
 
51
  image = np.stack((image,) * 3, axis=-1)
52
  elif len(image.shape) == 3 and image.shape[2] == 4: # RGBA image
 
53
  image = image[:, :, :3]
54
 
55
  # Resize image to match model's expected input shape
56
  target_size = (224, 224) # Change this to match your model's input size
57
  image = tf.image.resize(image, target_size)
 
58
 
59
  # Normalize pixel values
60
  image = image / 255.0
61
 
62
  # Add batch dimension
63
  image = np.expand_dims(image, axis=0)
 
64
 
65
  return image
66
- except Exception as e:
67
- logger.error(f"Error preprocessing image: {str(e)}")
68
- raise
69
-
70
- def create_chat_session():
71
- try:
72
- create_session_url = 'https://api.on-demand.io/chat/v1/sessions'
73
- create_session_headers = {
74
- 'apikey': api_key,
75
- 'Content-Type': 'application/json'
76
- }
77
- create_session_body = {
78
- "pluginIds": [],
79
- "externalUserId": external_user_id
80
- }
81
-
82
- response = requests.post(create_session_url, headers=create_session_headers, json=create_session_body)
83
- response.raise_for_status()
84
- return response.json()['data']['id']
85
-
86
- except requests.exceptions.RequestException as e:
87
- logger.error(f"Error creating chat session: {str(e)}")
88
- raise
89
-
90
- def submit_query(session_id, query, image_analysis=None):
91
- try:
92
- submit_query_url = f'https://api.on-demand.io/chat/v1/sessions/{session_id}/query'
93
- submit_query_headers = {
94
- 'apikey': api_key,
95
- 'Content-Type': 'application/json'
96
- }
97
-
98
- # Include image analysis in the query if available
99
- query_with_image = query
100
- if image_analysis:
101
- query_with_image += f"\n\nImage Analysis Results: {image_analysis}"
102
-
103
- structured_query = f"""
104
- Based on the following patient information and image analysis, provide a detailed medical analysis in JSON format:
105
- {query_with_image}
106
- Return only valid JSON with these fields:
107
- - diagnosis_details
108
- - probable_diagnoses (array)
109
- - treatment_plans (array)
110
- - lifestyle_modifications (array)
111
- - medications (array of objects with name and dosage)
112
- - additional_tests (array)
113
- - precautions (array)
114
- - follow_up (string)
115
- - image_findings (object with prediction and confidence)
116
- """
117
 
118
- submit_query_body = {
119
- "endpointId": "predefined-openai-gpt4o",
120
- "query": structured_query,
121
- "pluginIds": ["plugin-1712327325", "plugin-1713962163"],
122
- "responseMode": "sync"
123
- }
124
-
125
- response = requests.post(submit_query_url, headers=submit_query_headers, json=submit_query_body)
126
- response.raise_for_status()
127
- return response.json()
128
-
129
- except requests.exceptions.RequestException as e:
130
- logger.error(f"Error submitting query: {str(e)}")
131
- raise
132
-
133
- def extract_json_from_answer(answer):
134
- """Extract and clean JSON from the LLM response"""
135
- try:
136
- return json.loads(answer)
137
- except json.JSONDecodeError:
138
- try:
139
- # Find the first occurrence of '{' and last occurrence of '}'
140
- start_idx = answer.find('{')
141
- end_idx = answer.rfind('}') + 1
142
- if start_idx != -1 and end_idx != 0:
143
- json_str = answer[start_idx:end_idx]
144
- return json.loads(json_str)
145
- except (json.JSONDecodeError, ValueError):
146
- logger.error("Failed to parse JSON from response")
147
- raise
148
-
149
- def format_prediction(prediction):
150
- """Format model prediction into a standardized structure"""
151
- try:
152
- # Adjust this based on your model's output format
153
- confidence = float(prediction[0][0])
154
- return {
155
- "prediction": "abnormal" if confidence > 0.5 else "normal",
156
- "confidence": round(confidence * 100, 2)
157
- }
158
  except Exception as e:
159
- logger.error(f"Error formatting prediction: {str(e)}")
160
  raise
161
 
162
- # Initialize the model
163
- try:
164
- model = load_model()
165
- except Exception as e:
166
- logger.error(f"Failed to initialize model: {str(e)}")
167
- model = None
168
-
169
  def gradio_interface(patient_info, image):
170
  try:
171
  if model is None:
172
- raise ValueError("Model not properly initialized")
 
 
 
 
173
 
174
  # Process image if provided
175
  image_analysis = None
176
  if image is not None:
 
177
  # Preprocess image
178
  processed_image = preprocess_image(image)
179
 
180
  # Get model prediction
 
181
  prediction = model.predict(processed_image)
 
182
 
183
  # Format prediction results
184
- image_analysis = format_prediction(prediction)
 
 
 
 
185
 
186
  # Create chat session and submit query
187
  session_id = create_chat_session()
@@ -194,13 +144,17 @@ def gradio_interface(patient_info, image):
194
  # Extract and clean JSON from the response
195
  json_data = extract_json_from_answer(llm_response['data']['answer'])
196
 
197
- # Format output for better readability
198
  return json.dumps(json_data, indent=2)
199
-
200
  except Exception as e:
201
  logger.error(f"Error in gradio_interface: {str(e)}")
202
- return json.dumps({"error": str(e)}, indent=2)
203
-
 
 
 
 
 
204
  iface = gr.Interface(
205
  fn=gradio_interface,
206
  inputs=[
@@ -213,7 +167,7 @@ iface = gr.Interface(
213
  gr.Image(
214
  label="Medical Image",
215
  type="numpy",
216
- interactive=True, # This replaces the 'optional' parameter
217
  )
218
  ],
219
  outputs=gr.Textbox(
@@ -226,7 +180,15 @@ iface = gr.Interface(
226
  )
227
 
228
  if __name__ == "__main__":
229
- iface.launch()
 
 
 
 
 
 
 
 
230
 
231
 
232
 
 
7
  import numpy as np
8
  from PIL import Image
9
  import io
10
+ import os
11
 
12
+ # Set up logging with more detailed format
13
+ logging.basicConfig(
14
+ level=logging.INFO,
15
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
16
+ )
17
  logger = logging.getLogger(__name__)
18
 
19
  # API key and user ID for on-demand
20
  api_key = 'KGSjxB1uptfSk8I8A7ciCuNT9Xa3qWC3'
21
  external_user_id = 'plugin-1717464304'
22
 
 
23
  def load_model():
24
  try:
25
+ model_path = 'model_epoch_01.h5.keras'
26
+
27
+ # Check if model file exists
28
+ if not os.path.exists(model_path):
29
+ raise FileNotFoundError(f"Model file not found at {model_path}")
30
+
31
+ logger.info(f"Attempting to load model from {model_path}")
32
+
33
+ # Define custom objects dictionary
34
  custom_objects = {
35
  'KerasLayer': hub.KerasLayer,
36
+ 'BatchNormalization': tf.keras.layers.BatchNormalization,
37
+ # Add more custom objects if needed
38
  }
39
 
40
+ # Try loading with different configurations
41
+ try:
42
+ logger.info("Attempting to load model with custom objects...")
43
+ with tf.keras.utils.custom_object_scope(custom_objects):
44
+ model = tf.keras.models.load_model(model_path, compile=False)
45
+ except Exception as e:
46
+ logger.error(f"Failed to load with custom objects: {str(e)}")
47
+ logger.info("Attempting to load model without custom objects...")
48
+ model = tf.keras.models.load_model(model_path, compile=False)
49
+
50
+ # Verify model loaded correctly
51
+ if model is None:
52
+ raise ValueError("Model loading returned None")
53
 
54
+ # Print model summary for debugging
55
+ model.summary()
56
  logger.info("Model loaded successfully")
57
  return model
58
+
59
  except Exception as e:
60
  logger.error(f"Error loading model: {str(e)}")
61
+ logger.error(f"Model loading failed with exception type: {type(e)}")
62
  raise
63
 
64
+ # Initialize the model globally
65
+ try:
66
+ logger.info("Initializing model...")
67
+ model = load_model()
68
+ logger.info("Model initialization completed")
69
+ except Exception as e:
70
+ logger.error(f"Failed to initialize model: {str(e)}")
71
+ model = None
72
+
73
  def preprocess_image(image):
74
  try:
75
  # Convert to numpy array if needed
76
  if isinstance(image, Image.Image):
77
  image = np.array(image)
78
 
79
+ # Log image shape and type for debugging
80
+ logger.info(f"Input image shape: {image.shape}, dtype: {image.dtype}")
81
+
82
  # Ensure image has 3 channels (RGB)
83
  if len(image.shape) == 2: # Grayscale image
84
+ logger.info("Converting grayscale to RGB")
85
  image = np.stack((image,) * 3, axis=-1)
86
  elif len(image.shape) == 3 and image.shape[2] == 4: # RGBA image
87
+ logger.info("Converting RGBA to RGB")
88
  image = image[:, :, :3]
89
 
90
  # Resize image to match model's expected input shape
91
  target_size = (224, 224) # Change this to match your model's input size
92
  image = tf.image.resize(image, target_size)
93
+ logger.info(f"Resized image shape: {image.shape}")
94
 
95
  # Normalize pixel values
96
  image = image / 255.0
97
 
98
  # Add batch dimension
99
  image = np.expand_dims(image, axis=0)
100
+ logger.info(f"Final preprocessed image shape: {image.shape}")
101
 
102
  return image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
  except Exception as e:
105
+ logger.error(f"Error preprocessing image: {str(e)}")
106
  raise
107
 
 
 
 
 
 
 
 
108
  def gradio_interface(patient_info, image):
109
  try:
110
  if model is None:
111
+ logger.error("Model is not initialized")
112
+ return json.dumps({
113
+ "error": "Model initialization failed. Please check the logs for details.",
114
+ "status": "error"
115
+ }, indent=2)
116
 
117
  # Process image if provided
118
  image_analysis = None
119
  if image is not None:
120
+ logger.info("Processing uploaded image")
121
  # Preprocess image
122
  processed_image = preprocess_image(image)
123
 
124
  # Get model prediction
125
+ logger.info("Running model prediction")
126
  prediction = model.predict(processed_image)
127
+ logger.info(f"Raw prediction shape: {prediction.shape}")
128
 
129
  # Format prediction results
130
+ image_analysis = {
131
+ "prediction": float(prediction[0][0]),
132
+ "confidence": float(prediction[0][0]) * 100
133
+ }
134
+ logger.info(f"Image analysis results: {image_analysis}")
135
 
136
  # Create chat session and submit query
137
  session_id = create_chat_session()
 
144
  # Extract and clean JSON from the response
145
  json_data = extract_json_from_answer(llm_response['data']['answer'])
146
 
 
147
  return json.dumps(json_data, indent=2)
148
+
149
  except Exception as e:
150
  logger.error(f"Error in gradio_interface: {str(e)}")
151
+ return json.dumps({
152
+ "error": str(e),
153
+ "status": "error",
154
+ "details": "Check the application logs for more information"
155
+ }, indent=2)
156
+
157
+ # Gradio interface
158
  iface = gr.Interface(
159
  fn=gradio_interface,
160
  inputs=[
 
167
  gr.Image(
168
  label="Medical Image",
169
  type="numpy",
170
+ interactive=True
171
  )
172
  ],
173
  outputs=gr.Textbox(
 
180
  )
181
 
182
  if __name__ == "__main__":
183
+ # Add version information logging
184
+ logger.info(f"TensorFlow version: {tf.__version__}")
185
+ logger.info(f"TensorFlow Hub version: {hub.__version__}")
186
+ logger.info(f"Gradio version: {gr.__version__}")
187
+
188
+ iface.launch(
189
+ server_name="0.0.0.0",
190
+ debug=True
191
+ )
192
 
193
 
194