# NeuroGenAI / app.py
# Hugging Face Space by yashbyname; last commit b9bef1f ("Update app.py", verified), 14 kB.
import requests
import gradio as gr
import logging
import json
import tf_keras
import tensorflow_hub as hub
import numpy as np
from PIL import Image
import io
import os
# Root logging configuration: INFO and above, each record prefixed with its
# timestamp, logger name and level.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Module-level logger shared by every function in this app.
logger = logging.getLogger(__name__)
# API key and user ID for on-demand
api_key = 'KGSjxB1uptfSk8I8A7ciCuNT9Xa3qWC3'
external_user_id = 'plugin-1717464304'
def create_chat_session(timeout=30):
    """Create a new on-demand.io chat session and return its id.

    Args:
        timeout: Seconds to wait for the HTTP request. Without a timeout,
            ``requests`` can block indefinitely on an unresponsive server,
            stalling the Gradio request handler.

    Returns:
        The session id read from ``response.json()['data']['id']``.

    Raises:
        requests.exceptions.RequestException: On connection errors, timeouts
            or non-2xx responses (logged, then re-raised).
        KeyError: If the response body has no ``data.id``.
    """
    create_session_url = 'https://api.on-demand.io/chat/v1/sessions'
    create_session_headers = {
        'apikey': api_key,
        'Content-Type': 'application/json'
    }
    create_session_body = {
        "pluginIds": [],
        "externalUserId": external_user_id
    }
    try:
        response = requests.post(
            create_session_url,
            headers=create_session_headers,
            json=create_session_body,
            timeout=timeout,
        )
        response.raise_for_status()
        return response.json()['data']['id']
    except requests.exceptions.RequestException as e:
        logger.error(f"Error creating chat session: {str(e)}")
        raise
def submit_query(session_id, query, timeout=120):
    """Send patient information to a chat session and return the raw API response.

    The query is embedded in a prompt asking the ``predefined-openai-gpt4o``
    endpoint (synchronous response mode) to reply with JSON containing a fixed
    set of fields.

    Args:
        session_id: Session id returned by ``create_chat_session``.
        query: Free-text patient information inserted into the prompt.
        timeout: Seconds to wait for the response. Synchronous LLM answers can
            be slow, so the default is generous, but it still prevents an
            unresponsive server from blocking the request handler forever.

    Returns:
        The decoded JSON body of the response.

    Raises:
        requests.exceptions.RequestException: On connection errors, timeouts
            or non-2xx responses (logged, then re-raised).
    """
    submit_query_url = f'https://api.on-demand.io/chat/v1/sessions/{session_id}/query'
    submit_query_headers = {
        'apikey': api_key,
        'Content-Type': 'application/json'
    }
    structured_query = f"""
Based on the following patient information, provide a detailed medical analysis in JSON format:
{query}
Return only valid JSON with these fields:
- diagnosis_details
- probable_diagnoses (array)
- treatment_plans (array)
- lifestyle_modifications (array)
- medications (array of objects with name and dosage)
- additional_tests (array)
- precautions (array)
- follow_up (string)
"""
    submit_query_body = {
        "endpointId": "predefined-openai-gpt4o",
        "query": structured_query,
        "pluginIds": ["plugin-1712327325", "plugin-1713962163"],
        "responseMode": "sync"
    }
    try:
        response = requests.post(
            submit_query_url,
            headers=submit_query_headers,
            json=submit_query_body,
            timeout=timeout,
        )
        response.raise_for_status()
        return response.json()
    except requests.exceptions.RequestException as e:
        logger.error(f"Error submitting query: {str(e)}")
        raise
def extract_json_from_answer(answer, image_analysis):
    """Parse the JSON object from an LLM answer and attach image-analysis results.

    The answer is first parsed as-is. If that fails, or if it parses to
    something other than a JSON object (e.g. a list or a bare string), the
    substring between the first '{' and the last '}' is parsed instead. This
    tolerates answers wrapped in prose or Markdown code fences.

    Args:
        answer: Raw answer text returned by the LLM.
        image_analysis: Optional dict with "prediction" (class label) and
            "confidence" (numeric percentage). When truthy, it is stored in the
            result under "image_analysis", with confidence formatted as "NN.NN%".

    Returns:
        dict: The parsed JSON object, plus an "image_analysis" entry if given.

    Raises:
        json.JSONDecodeError: If the extracted substring is not valid JSON.
        ValueError: If no '{...}' span can be located in the answer.
    """
    try:
        json_data = json.loads(answer)
    except json.JSONDecodeError:
        json_data = None

    # Only a JSON object can carry the requested fields (and the image entry);
    # anything else means the object is embedded in surrounding text.
    if not isinstance(json_data, dict):
        try:
            start_idx = answer.find('{')
            end_idx = answer.rfind('}') + 1
            # end_idx == 0 means no '}' at all; end_idx <= start_idx means the
            # braces are out of order. Neither yields a parsable object.
            if start_idx == -1 or end_idx <= start_idx:
                raise ValueError("Failed to locate JSON in the answer")
            json_data = json.loads(answer[start_idx:end_idx])
        except (json.JSONDecodeError, ValueError) as e:
            logger.error(f"Failed to parse JSON from response: {str(e)}")
            raise

    # Append the image analysis data
    if image_analysis:
        json_data["image_analysis"] = {
            "prediction": image_analysis["prediction"],
            "confidence": f"{image_analysis['confidence']:.2f}%"  # Format confidence as percentage
        }
    return json_data
def load_model(model_path='model_epoch_01.h5.keras'):
    """Load the Keras image classifier from disk.

    Loading is first attempted with ``hub.KerasLayer`` registered under the name
    ``'KerasLayer'`` in ``custom_objects``, so that such layers in the saved model
    can be deserialized. If that attempt raises, the error is logged and a plain
    load without custom objects is tried.

    Args:
        model_path: Path to the saved model. Defaults to the file this app has
            always loaded.

    Returns:
        The loaded model. Its summary is printed for debugging.

    Raises:
        FileNotFoundError: If ``model_path`` does not exist.
        ValueError: If loading returns ``None``.
        Exception: Whatever the fallback ``load_model`` call raises. All errors
            are logged before being re-raised.
    """
    try:
        # Check if model file exists
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Model file not found at {model_path}")
        logger.info(f"Attempting to load model from {model_path}")
        # Custom layer classes the saved model may reference by name.
        custom_objects = {
            'KerasLayer': hub.KerasLayer,
            # Add more custom objects if needed
        }
        try:
            logger.info("Attempting to load model with custom objects...")
            loaded_model = tf_keras.models.load_model(model_path, custom_objects=custom_objects)
        except Exception as e:
            logger.error(f"Failed to load with custom objects: {str(e)}")
            logger.info("Attempting to load model without custom objects...")
            loaded_model = tf_keras.models.load_model(model_path)
        # Verify model loaded correctly
        if loaded_model is None:
            raise ValueError("Model loading returned None")
        # Print model summary for debugging
        loaded_model.summary()
        logger.info("Model loaded successfully")
        return loaded_model
    except Exception as e:
        logger.error(f"Error loading model: {str(e)}")
        logger.error(f"Model loading failed with exception type: {type(e)}")
        raise
# Load the classifier once at import time. If loading fails, `model` stays None
# and the request handler reports the failure instead of the app crashing.
model = None
try:
    logger.info("Initializing model...")
    model = load_model()
except Exception as e:
    logger.error(f"Failed to initialize model: {str(e)}")
else:
    logger.info("Model initialization completed")
def preprocess_image(image, target_size=(256, 256)):
    """Convert a PIL image into a normalized single-image batch for the classifier.

    Args:
        image: A PIL image in any mode; it is converted to RGB.
        target_size: ``(width, height)`` passed to ``Image.resize``. Defaults to
            256x256, the size this app has always fed the model.

    Returns:
        np.ndarray of shape ``(1, height, width, 3)``, dtype float32, with
        pixel values scaled into [0, 1].

    Raises:
        Exception: Any error from conversion or resizing (logged, then re-raised).
    """
    try:
        rgb_image = image.convert('RGB').resize(target_size)
        # float32 is Keras' default float type; building it directly avoids an
        # intermediate float64 array twice the size.
        pixels = np.asarray(rgb_image, dtype=np.float32) / 255.0
        # Add batch dimension
        batch = np.expand_dims(pixels, axis=0)
        logger.info(f"Final preprocessed image shape: {batch.shape}")
        return batch
    except Exception as e:
        logger.error(f"Error preprocessing image: {str(e)}")
        raise
def gradio_interface(patient_info, image):
try:
if model is None:
logger.error("Model is not initialized")
return json.dumps({
"error": "Model initialization failed. Please check the logs for details.",
"status": "error"
}, indent=2)
classes = ["Alzheimer's", "Normal", "Stroke", "Tumor"]
# Process image if provided
image_analysis = None
if image is not None:
logger.info("Processing uploaded image")
# Preprocess image
processed_image = preprocess_image(image)
# Get model prediction
logger.info("Running model prediction")
prediction = model.predict(processed_image)
logger.info(f"Raw prediction shape: {prediction.shape}")
logger.info(f"Prediction: {prediction}")
# Format prediction results
image_analysis = {
"prediction": classes[np.argmax(prediction[0])],
"confidence": np.max(prediction[0]) * 100
}
logger.info(f"Image analysis results: {image_analysis}")
patient_info += f"Prediction based on MRI images: {image_analysis['prediction']}, Confidence: {image_analysis['confidence']}"
# Create chat session and submit query
session_id = create_chat_session()
llm_response = submit_query(session_id, patient_info)
if not llm_response or 'data' not in llm_response or 'answer' not in llm_response['data']:
raise ValueError("Invalid response structure from LLM")
# Extract and clean JSON from the response
logger.info(f"llm_response: {llm_response}")
logger.info(f"llm_response[data]: {llm_response['data']}")
logger.info(f"llm_response[data][answer]: {llm_response['data']['answer']}")
json_data = extract_json_from_answer(llm_response['data']['answer'], image_analysis)
return json.dumps(json_data, indent=2)
except Exception as e:
logger.error(f"Error in gradio_interface: {str(e)}")
return json.dumps({
"error": str(e),
"status": "error",
"details": "Check the application logs for more information"
}, indent=2)
# Gradio UI: free-text patient details plus an optional image go in; the JSON
# analysis comes out as text.
_patient_info_box = gr.Textbox(
    label="Patient Information",
    placeholder="Enter patient details including: symptoms, medical history, current medications, age, gender, and any relevant test results...",
    max_lines=10,
    lines=5,
)
_medical_image_input = gr.Image(
    interactive=True,
    type="pil",  # the handler receives a PIL image
    label="Medical Image",
)
_analysis_box = gr.Textbox(
    lines=15,
    placeholder="JSON analysis will appear here...",
    label="Medical Analysis",
)

iface = gr.Interface(
    fn=gradio_interface,
    inputs=[_patient_info_box, _medical_image_input],
    outputs=_analysis_box,
    title="Medical Diagnosis Assistant",
    description="Enter patient information and optionally upload a medical image for analysis.",
)
if __name__ == "__main__":
    # Record library versions at startup to help diagnose environment issues.
    for _label, _module in (
        ("TensorFlow Keras", tf_keras),
        ("TensorFlow Hub", hub),
        ("Gradio", gr),
    ):
        logger.info(f"{_label} version: {_module.__version__}")
    # Listen on all interfaces so the app is reachable from outside its container.
    iface.launch(debug=True, server_name="0.0.0.0")
# import requests
# import gradio as gr
# import logging
# import json
# # Set up logging
# logging.basicConfig(level=logging.INFO)
# logger = logging.getLogger(__name__)
# # API key and user ID for on-demand
# api_key = '<redacted>'  # hard-coded key removed; load it from the environment instead
# external_user_id = 'plugin-1717464304'
# def create_chat_session():
# try:
# create_session_url = 'https://api.on-demand.io/chat/v1/sessions'
# create_session_headers = {
# 'apikey': api_key,
# 'Content-Type': 'application/json'
# }
# create_session_body = {
# "pluginIds": [],
# "externalUserId": external_user_id
# }
# response = requests.post(create_session_url, headers=create_session_headers, json=create_session_body)
# response.raise_for_status()
# return response.json()['data']['id']
# except requests.exceptions.RequestException as e:
# logger.error(f"Error creating chat session: {str(e)}")
# raise
# def submit_query(session_id, query):
# try:
# submit_query_url = f'https://api.on-demand.io/chat/v1/sessions/{session_id}/query'
# submit_query_headers = {
# 'apikey': api_key,
# 'Content-Type': 'application/json'
# }
# structured_query = f"""
# Based on the following patient information, provide a detailed medical analysis in JSON format:
# {query}
# Return only valid JSON with these fields:
# - diagnosis_details
# - probable_diagnoses (array)
# - treatment_plans (array)
# - lifestyle_modifications (array)
# - medications (array of objects with name and dosage)
# - additional_tests (array)
# - precautions (array)
# - follow_up (string)
# """
# submit_query_body = {
# "endpointId": "predefined-openai-gpt4o",
# "query": structured_query,
# "pluginIds": ["plugin-1712327325", "plugin-1713962163"],
# "responseMode": "sync"
# }
# response = requests.post(submit_query_url, headers=submit_query_headers, json=submit_query_body)
# response.raise_for_status()
# return response.json()
# except requests.exceptions.RequestException as e:
# logger.error(f"Error submitting query: {str(e)}")
# raise
# def extract_json_from_answer(answer):
# """Extract and clean JSON from the LLM response"""
# try:
# # First try to parse the answer directly
# return json.loads(answer)
# except json.JSONDecodeError:
# try:
# # If that fails, try to find JSON content and parse it
# start_idx = answer.find('{')
# end_idx = answer.rfind('}') + 1
# if start_idx != -1 and end_idx != 0:
# json_str = answer[start_idx:end_idx]
# return json.loads(json_str)
# except (json.JSONDecodeError, ValueError):
# logger.error("Failed to parse JSON from response")
# raise
# def gradio_interface(patient_info):
# try:
# session_id = create_chat_session()
# llm_response = submit_query(session_id, patient_info)
# if not llm_response or 'data' not in llm_response or 'answer' not in llm_response['data']:
# raise ValueError("Invalid response structure")
# # Extract and clean JSON from the response
# json_data = extract_json_from_answer(llm_response['data']['answer'])
# # Return clean JSON string without extra formatting
# return json.dumps(json_data)
# except Exception as e:
# logger.error(f"Error in gradio_interface: {str(e)}")
# return json.dumps({"error": str(e)})
# # Gradio interface
# iface = gr.Interface(
# fn=gradio_interface,
# inputs=[
# gr.Textbox(
# label="Patient Information",
# placeholder="Enter patient details including: symptoms, medical history, current medications, age, gender, and any relevant test results...",
# lines=5,
# max_lines=10
# )
# ],
# outputs=gr.Textbox(
# label="Medical Analysis",
# placeholder="JSON analysis will appear here...",
# lines=15
# ),
# title="Medical Diagnosis Assistant",
# description="Enter detailed patient information to receive a structured medical analysis in JSON format."
# )
# if __name__ == "__main__":
# iface.launch()