yashbyname commited on
Commit
e7f38cd
·
verified ·
1 Parent(s): 3c7b3ef

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +114 -0
app.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import tensorflow as tf
4
+ from tensorflow import keras
5
+ import tensorflow_hub as hub
6
+ from PIL import Image
7
+
8
+ # Load models
9
+ model_initial = keras.models.load_model(
10
+ "models/initial_model.h5", custom_objects={'KerasLayer': hub.KerasLayer}
11
+ )
12
+ model_tumor = keras.models.load_model(
13
+ "models/model_tumor.h5", custom_objects={'KerasLayer': hub.KerasLayer}
14
+ )
15
+ model_stroke = keras.models.load_model(
16
+ "models/model_stroke.h5", custom_objects={'KerasLayer': hub.KerasLayer}
17
+ )
18
+ model_alzheimer = keras.models.load_model(
19
+ "models/model_alzheimer.h5", custom_objects={'KerasLayer': hub.KerasLayer}
20
+ )
21
+
22
+ class CombinedDiseaseModel(tf.keras.Model):
23
+ def __init__(self, model_initial, model_alzheimer, model_tumor, model_stroke):
24
+ super(CombinedDiseaseModel, self).__init__()
25
+ self.model_initial = model_initial
26
+ self.model_alzheimer = model_alzheimer
27
+ self.model_tumor = model_tumor
28
+ self.model_stroke = model_stroke
29
+ self.disease_labels = ["Alzheimer's", 'No Disease', 'Stroke', 'Tumor']
30
+
31
+ self.sub_models = {
32
+ "Alzheimer's": model_alzheimer,
33
+ 'Tumor': model_tumor,
34
+ 'Stroke': model_stroke
35
+ }
36
+
37
+ def call(self, inputs):
38
+ initial_probs = self.model_initial(inputs, training=False)
39
+ main_disease_idx = tf.argmax(initial_probs, axis=1)
40
+ main_disease = self.disease_labels[main_disease_idx[0].numpy()]
41
+ main_disease_prob = initial_probs[0, main_disease_idx[0]].numpy()
42
+
43
+ if main_disease == 'No Disease':
44
+ sub_category = "No Disease"
45
+ sub_category_prob = main_disease_prob
46
+ else:
47
+ sub_model = self.sub_models[main_disease]
48
+ sub_category_pred = sub_model(inputs, training=False)
49
+ sub_category = tf.argmax(sub_category_pred, axis=1).numpy()[0]
50
+ sub_category_prob = sub_category_pred[0, sub_category].numpy()
51
+
52
+ if main_disease == "Alzheimer's":
53
+ sub_category_label = ['Very Mild', 'Mild', 'Moderate']
54
+ elif main_disease == 'Tumor':
55
+ sub_category_label = ['Glioma', 'Meningioma', 'Pituitary']
56
+ elif main_disease == 'Stroke':
57
+ sub_category_label = ['Ischemic', 'Hemorrhagic']
58
+
59
+ sub_category = sub_category_label[sub_category]
60
+
61
+ return f"The MRI image shows {main_disease} with a probability of {main_disease_prob*100:.2f}%.\nThe subcategory of {main_disease} is {sub_category} with a probability of {sub_category_prob*100:.2f}%."
62
+
63
+
64
+ # Initialize the combined model
65
+ cnn_model = CombinedDiseaseModel(
66
+ model_initial=model_initial,
67
+ model_alzheimer=model_alzheimer,
68
+ model_tumor=model_tumor,
69
+ model_stroke=model_stroke
70
+ )
71
+
72
+
73
+ def process_image(image):
74
+ image = image.resize((256, 256))
75
+ image.convert("RGB")
76
+ image_array = np.array(image) / 255.0
77
+ image_array = np.expand_dims(image_array, axis=0)
78
+ predictions = cnn_model(image_array)
79
+ return predictions
80
+
81
+
82
+ def gradio_interface(patient_info, query_type, image):
83
+ if image is not None:
84
+ image_response = process_image(image)
85
+ response = f"Patient Info: {patient_info}\nQuery Type: {query_type}\n{image_response}"
86
+ return response
87
+ else:
88
+ return "Please upload an image."
89
+
90
+
91
+ # Create Gradio app
92
+ iface = gr.Interface(
93
+ fn=gradio_interface,
94
+ inputs=[
95
+ gr.Textbox(
96
+ label="Patient Information",
97
+ placeholder="Enter patient details here...",
98
+ lines=5,
99
+ max_lines=10
100
+ ),
101
+ gr.Textbox(
102
+ label="Query Type"
103
+ ),
104
+ gr.Image(
105
+ type="pil",
106
+ label="Upload an Image",
107
+ )
108
+ ],
109
+ outputs=gr.Textbox(label="Response", placeholder="The response will appear here..."),
110
+ title="Medical Diagnosis with MRI",
111
+ description="Upload MRI images and provide patient information for diagnosis.",
112
+ )
113
+
114
+ iface.launch()