import requests
import gradio as gr
import logging
import json
import tf_keras
import tensorflow_hub as hub
import numpy as np
from PIL import Image
import io
import os
# Set up logging with a detailed format
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
# API key and user ID for on-demand
api_key = 'KGSjxB1uptfSk8I8A7ciCuNT9Xa3qWC3'
external_user_id = 'plugin-1717464304'
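
# Hardcoding credentials in source is risky, especially in a public Space.
# A minimal hardening sketch (the environment variable names here are
# assumptions, not part of the original app): prefer the environment and
# fall back to the literals above so local behavior is unchanged.
api_key = os.environ.get('ONDEMAND_API_KEY', api_key)
external_user_id = os.environ.get('ONDEMAND_EXTERNAL_USER_ID', external_user_id)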
def create_chat_session():
    try:
        create_session_url = 'https://api.on-demand.io/chat/v1/sessions'
        create_session_headers = {
            'apikey': api_key,
            'Content-Type': 'application/json'
        }
        create_session_body = {
            "pluginIds": [],
            "externalUserId": external_user_id
        }
        response = requests.post(create_session_url, headers=create_session_headers, json=create_session_body)
        response.raise_for_status()
        return response.json()['data']['id']
    except requests.exceptions.RequestException as e:
        logger.error(f"Error creating chat session: {str(e)}")
        raise
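
# Flow sketch: a session is created once per request, then queried. The
# response shape {"data": {"id": ...}} above is inferred from this code, not
# from API docs, so treat it as an assumption:
#   session_id = create_chat_session()
#   llm_response = submit_query(session_id, "55-year-old male, chest pain...")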
def submit_query(session_id, query):
    try:
        submit_query_url = f'https://api.on-demand.io/chat/v1/sessions/{session_id}/query'
        submit_query_headers = {
            'apikey': api_key,
            'Content-Type': 'application/json'
        }
        structured_query = f"""
        Based on the following patient information, provide a detailed medical analysis in JSON format:

        {query}

        Return only valid JSON with these fields:
        - diagnosis_details
        - probable_diagnoses (array)
        - treatment_plans (array)
        - lifestyle_modifications (array)
        - medications (array of objects with name and dosage)
        - additional_tests (array)
        - precautions (array)
        - follow_up (string)
        """
        submit_query_body = {
            "endpointId": "predefined-openai-gpt4o",
            "query": structured_query,
            "pluginIds": ["plugin-1712327325", "plugin-1713962163"],
            "responseMode": "sync"
        }
        response = requests.post(submit_query_url, headers=submit_query_headers, json=submit_query_body)
        response.raise_for_status()
        return response.json()
    except requests.exceptions.RequestException as e:
        logger.error(f"Error submitting query: {str(e)}")
        raise
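
# Both POSTs above block with no timeout; with "responseMode": "sync" the
# query call can run long. A hedged tweak (the timeout values are
# assumptions): pass timeout=(5, 120) -- 5 s to connect, 120 s to read -- so
# a stalled request raises requests.exceptions.Timeout instead of hanging
# the Gradio worker:
#   requests.post(submit_query_url, headers=submit_query_headers,
#                 json=submit_query_body, timeout=(5, 120))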
def extract_json_from_answer(answer, image_analysis):
    """Extract and clean JSON from the LLM response and append image analysis results."""
    try:
        # Try to parse the answer as JSON directly
        json_data = json.loads(answer)
    except json.JSONDecodeError:
        try:
            # If that fails, try to find JSON content and parse it
            start_idx = answer.find('{')
            end_idx = answer.rfind('}') + 1
            if start_idx != -1 and end_idx != 0:
                json_str = answer[start_idx:end_idx]
                json_data = json.loads(json_str)
            else:
                raise ValueError("Failed to locate JSON in the answer")
        except (json.JSONDecodeError, ValueError) as e:
            logger.error(f"Failed to parse JSON from response: {str(e)}")
            raise

    # Append the image analysis data
    if image_analysis:
        json_data["image_analysis"] = {
            "prediction": image_analysis["prediction"],
            "confidence": f"{image_analysis['confidence']:.2f}%"  # Format confidence as a percentage
        }
    return json_data
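
# Worked example (illustrative values only): an answer such as
#   'Sure, here it is: {"diagnosis_details": "...", "follow_up": "2 weeks"}'
# fails direct json.loads, so the fallback slices from the first '{' to the
# last '}' and parses that substring. Note the slice is brittle if the model
# emits multiple JSON objects or stray braces in surrounding prose.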
def load_model():
    try:
        model_path = 'model_epoch_01.h5.keras'
        # Check if the model file exists
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Model file not found at {model_path}")
        logger.info(f"Attempting to load model from {model_path}")

        # Define custom objects needed to deserialize the TF-Hub layer
        custom_objects = {
            'KerasLayer': hub.KerasLayer
            # Add more custom objects if needed
        }

        # Try loading with custom objects first, then fall back without them
        try:
            logger.info("Attempting to load model with custom objects...")
            model = tf_keras.models.load_model(model_path, custom_objects=custom_objects)
        except Exception as e:
            logger.error(f"Failed to load with custom objects: {str(e)}")
            logger.info("Attempting to load model without custom objects...")
            model = tf_keras.models.load_model(model_path)

        # Verify the model loaded correctly
        if model is None:
            raise ValueError("Model loading returned None")

        # Print the model summary for debugging
        model.summary()
        logger.info("Model loaded successfully")
        return model
    except Exception as e:
        logger.error(f"Error loading model: {str(e)}")
        logger.error(f"Model loading failed with exception type: {type(e)}")
        raise
# Initialize the model globally
try:
    logger.info("Initializing model...")
    model = load_model()
    logger.info("Model initialization completed")
except Exception as e:
    logger.error(f"Failed to initialize model: {str(e)}")
    model = None
def preprocess_image(image):
    try:
        # Log image size and mode for debugging
        logger.info(f"Input image size: {image.size}, mode: {image.mode}")
        image = image.convert('RGB')
        image = image.resize((256, 256))
        image = np.array(image)
        # Normalize pixel values to [0, 1]
        image = image / 255.0
        # Add batch dimension
        image = np.expand_dims(image, axis=0)
        logger.info(f"Final preprocessed image shape: {image.shape}")
        return image
    except Exception as e:
        logger.error(f"Error preprocessing image: {str(e)}")
        raise
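
# The 256x256 RGB resize and 1/255 scaling are assumptions that must match
# the checkpoint's training pipeline. A quick hedged sanity check once the
# model is loaded:
#   assert model.input_shape[1:3] == (256, 256), model.input_shape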
def gradio_interface(patient_info, image):
    try:
        if model is None:
            logger.error("Model is not initialized")
            return json.dumps({
                "error": "Model initialization failed. Please check the logs for details.",
                "status": "error"
            }, indent=2)

        classes = ["Alzheimer's", "Normal", "Stroke", "Tumor"]

        # Process the image if one was provided
        image_analysis = None
        if image is not None:
            logger.info("Processing uploaded image")
            # Preprocess the image
            processed_image = preprocess_image(image)
            # Get the model prediction
            logger.info("Running model prediction")
            prediction = model.predict(processed_image)
            logger.info(f"Raw prediction shape: {prediction.shape}")
            logger.info(f"Prediction: {prediction}")
            # Format the prediction results
            image_analysis = {
                "prediction": classes[np.argmax(prediction[0])],
                "confidence": np.max(prediction[0]) * 100
            }
            logger.info(f"Image analysis results: {image_analysis}")
            patient_info += f"\nPrediction based on MRI images: {image_analysis['prediction']}, Confidence: {image_analysis['confidence']:.2f}%"

        # Create a chat session and submit the query
        session_id = create_chat_session()
        llm_response = submit_query(session_id, patient_info)
        if not llm_response or 'data' not in llm_response or 'answer' not in llm_response['data']:
            raise ValueError("Invalid response structure from LLM")

        # Extract and clean JSON from the response
        logger.info(f"llm_response: {llm_response}")
        logger.info(f"llm_response[data]: {llm_response['data']}")
        logger.info(f"llm_response[data][answer]: {llm_response['data']['answer']}")
        json_data = extract_json_from_answer(llm_response['data']['answer'], image_analysis)
        return json.dumps(json_data, indent=2)
    except Exception as e:
        logger.error(f"Error in gradio_interface: {str(e)}")
        return json.dumps({
            "error": str(e),
            "status": "error",
            "details": "Check the application logs for more information"
        }, indent=2)
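
# Note: the classes list above assumes the checkpoint's output units are
# ordered ["Alzheimer's", "Normal", "Stroke", "Tumor"] with a softmax head;
# if the training label order differs, predictions will be mislabeled.
# Quick smoke test (patient details invented; this hits the live API):
#   print(gradio_interface("62-year-old female, sudden left-side weakness", None))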
# Gradio interface
iface = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Textbox(
            label="Patient Information",
            placeholder="Enter patient details including: symptoms, medical history, current medications, age, gender, and any relevant test results...",
            lines=5,
            max_lines=10
        ),
        gr.Image(
            label="Medical Image",
            type="pil",
            interactive=True
        )
    ],
    outputs=gr.Textbox(
        label="Medical Analysis",
        placeholder="JSON analysis will appear here...",
        lines=15
    ),
    title="Medical Diagnosis Assistant",
    description="Enter patient information and optionally upload a medical image for analysis."
)
if __name__ == "__main__":
    # Log version information
    logger.info(f"TensorFlow Keras version: {tf_keras.__version__}")
    logger.info(f"TensorFlow Hub version: {hub.__version__}")
    logger.info(f"Gradio version: {gr.__version__}")
    iface.launch(
        server_name="0.0.0.0",
        debug=True
    )
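
# Dependency sketch for this Space's requirements.txt (package names only;
# version pins are deliberately omitted since the tested versions are unknown):
#   requests
#   gradio
#   numpy
#   Pillow
#   tf-keras
#   tensorflow-hub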