Jiranuwat's picture
Update app.py
f6a1753 verified
import gradio as gr
import torch
from utils import *
import os
model = torch.load('AlzheimerModelCPU.pth')
model.eval()
def reset():
return None, None, None, None, None
with gr.Blocks() as demo:
gr.HTML("""
<h1 style="text-align: center; font-size: 50px;">
Alzheimer Detection with Explainable AI (XAI)
</h1>
<p style="text-align: center;">
Early Detection of Alzheimer's Disease: A Deep Learning Approach for Accurate Diagnosis.
</p>
<h3 style="text-align: left;">
To use the demo, please follow the steps below.
</h3>
<ul>
<li style="text-align: left; padding: 0px 0px 0px 30px;">
If you want to try examples, click one of <span style="font-weight: bold;">Examples</span> images below. Then, click
<span style="font-weight: bold;">Submit</span>.
</li>
<li style="text-align: left; padding: 0px 0px 0px 30px;">
if you don't want to try examples, upload an image and click <span style="font-weight: bold;">Submit</span>.
</li>
<li style="text-align: left; padding: 0px 0px 0px 30px;">
You can adjust the <span style="font-weight: bold;">Target</span> for your desire visualization and
<span style="font-weight: bold;">Plot Type</span> between <span style="font-weight: bold;">withmask</span> and <span style="font-weight: bold;">withoutmask</span>
to plot original images with or without Grad-CAM.
</li>
<li style="text-align: left; padding: 0px 0px 0px 30px;">
If you want to reset all components, click <span style="font-weight: bold;">Reset All Components</span> button.
</li>
</ul>
"""
)
with gr.Row():
with gr.Column():
image_area = gr.Image(sources=["upload"], type="pil", scale=2, label="Upload Image")
with gr.Row():
choosen_plottype =gr.Radio(choices=["withmask", "withoutmask"], value="withmask", label="Plot Type", scale=1, interactive=True)
choosen_target = gr.Slider(minimum=0, maximum=3, step=1, value=2, label="Classifier OutputTarget", scale=1, interactive=True)
submit_btn = gr.Button("Submit", variant='primary', scale=1)
with gr.Column():
gr.HTML("""
<h2 style="text-align: center;">
</h2>
<h2 style="text-align: center;">
Prediction and Grad-CAM
</h2>
<p style="text-align: center;">
Alzheimer detection result with Grad-CAM for 3 images. Original, White Matter and Gray Matter image.
</p>
"""
)
text_area = gr.Textbox(label="Prediction Result")
plotarea_original = gr.Plot(label="Original image")
plotarea_white = gr.Plot(label="White Matter image")
plotarea_gray = gr.Plot(label="Gray Matter image")
reset_btn = gr.Button("Reset All Components", variant='stop', scale=1)
gr.HTML("""
<h2 style="text-align: left;">
Examples
</h2>
<p style="text-align: left;">
You can select 1 image from the examples and click "Submit".
</p>
<p style="text-align: left;">
These images represent 4 different stages of Alzheimer's disease in the following order:<span style="font-weight: bold;">
Non Demented, Very Mild Demented, Mild Demented and Moderate Demented.</span>
</p>
"""
)
examples = gr.Examples(examples=[ os.path.join(os.getcwd(), "non.jpg"),
os.path.join(os.getcwd(),"verymild.jpg"),
os.path.join(os.getcwd(), "mild.jpg"),
os.path.join(os.getcwd(), "moderate.jpg")],
inputs=image_area,
outputs=image_area,
label="Examples",
)
submit_btn.click(lambda x, target, plot_type: predict_and_gradcam(x, model=model, target=target, plot_type=plot_type), inputs=[image_area, choosen_target, choosen_plottype], outputs=[text_area, plotarea_original, plotarea_white, plotarea_gray])
reset_btn.click(reset, outputs=[image_area, text_area, plotarea_original, plotarea_white, plotarea_gray])
demo.launch()