Upload 10 files
- AlzheimerModelCPU.pth +3 -0
- AlzheimerTriMatterNet.py +24 -0
- Resnet18.py +100 -0
- app.py +93 -0
- mild.jpg +0 -0
- moderate.jpg +0 -0
- non.jpg +0 -0
- requirements.txt +7 -0
- utils.py +182 -0
- verymild.jpg +0 -0
AlzheimerModelCPU.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2241af42d3188aea27dcc36259f72e57a6ad06b555e1687c3f298e52c3128ce8
size 134441434
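Note: the .pth entry above is a Git LFS pointer; once LFS materializes the actual ~134 MB checkpoint, app.py loads it directly. A minimal sketch, assuming the file was saved with torch.save(model) so that torch.load returns the full nn.Module rather than a state_dict:

import torch

# Deserializes the whole pickled module, not just a state_dict.
model = torch.load("AlzheimerModelCPU.pth", map_location="cpu")
model.eval()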
AlzheimerTriMatterNet.py
ADDED
@@ -0,0 +1,24 @@
import torch
import torch.nn as nn
from Resnet18 import *

class AlzheimerTriMatterNet(nn.Module):
    def __init__(self):
        super(AlzheimerTriMatterNet, self).__init__()
        self.numclass = 4
        # One ResNet18 backbone per input stream: white matter, gray matter, and the original slice.
        self.whitematter_resnet18_model = ResNet18(img_channels=3, num_layers=18, block=BasicBlock, num_classes=4)
        self.graymatter_resnet18_model = ResNet18(img_channels=3, num_layers=18, block=BasicBlock, num_classes=4)
        self.resnet18_model = ResNet18(img_channels=3, num_layers=18, block=BasicBlock, num_classes=4)
        # Each backbone returns a 512-d feature vector (its fc layer is skipped in
        # Resnet18.forward), so the fused classification head sees 512*3 features.
        self.global_classification_head = nn.Sequential(
            nn.Linear(512*3, self.numclass),
            nn.Softmax(dim=1),
        )

    def forward(self, whitematter, graymatter, original):
        white_output = self.whitematter_resnet18_model(whitematter)
        gray_output = self.graymatter_resnet18_model(graymatter)
        origin_output = self.resnet18_model(original)
        # Concatenate the three embeddings along the feature dimension.
        combined_tensor = torch.cat((white_output, gray_output, origin_output), dim=1)
        output = self.global_classification_head(combined_tensor)
        return output
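A minimal smoke test for the fused model (hypothetical; the (1, 3, 150, 150) shape matches the 150x150, 3-channel preprocessing in utils.py):

import torch

net = AlzheimerTriMatterNet()
net.eval()
dummy = torch.randn(1, 3, 150, 150)   # one preprocessed slice per branch
probs = net(dummy.clone(), dummy.clone(), dummy)
print(probs.shape)        # torch.Size([1, 4])
print(probs.sum(dim=1))   # ~1.0 per row, because of the Softmax head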
Resnet18.py
ADDED
@@ -0,0 +1,100 @@
"""Code from: https://debuggercafe.com/implementing-resnet18-in-pytorch-from-scratch/"""
import torch
import torch.nn as nn
from torch import Tensor
from typing import Type

class BasicBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, stride: int = 1, expansion: int = 1, downsample: nn.Module = None) -> None:
        super(BasicBlock, self).__init__()
        # Multiplicative factor for the second conv2d layer's output channels.
        # It is 1 for ResNet18 and ResNet34.
        self.expansion = expansion
        self.downsample = downsample
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels*self.expansion, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels*self.expansion)

    def forward(self, x: Tensor) -> Tensor:
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        if self.downsample is not None:
            identity = self.downsample(x)
        out += identity
        out = self.relu(out)
        return out

class ResNet18(nn.Module):
    def __init__(self, img_channels: int, num_layers: int, block: Type[BasicBlock], num_classes: int = 1000) -> None:
        super(ResNet18, self).__init__()
        if num_layers == 18:
            # `layers` defines how many `BasicBlock`s to stack in each of the
            # four stages of the network.
            layers = [2, 2, 2, 2]
            self.expansion = 1

        self.in_channels = 64
        # All ResNets (18 to 152) start with a Conv2d => BN => ReLU stem.
        # Here, the kernel size is 7.
        self.conv1 = nn.Conv2d(in_channels=img_channels, out_channels=self.in_channels, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_channels)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512*self.expansion, num_classes)

    def _make_layer(self, block: Type[BasicBlock], out_channels: int, blocks: int, stride: int = 1) -> nn.Sequential:
        downsample = None
        if stride != 1:
            # A strided 1x1 conv matches the identity branch to the new spatial
            # size and channel count; this path is taken from `layer2` to `layer4`,
            # or when building ResNet50 and above. See Section 3.3 of
            # "Deep Residual Learning for Image Recognition"
            # (https://arxiv.org/pdf/1512.03385v1.pdf).
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channels, out_channels*self.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels * self.expansion),
            )
        layers = []
        layers.append(
            block(self.in_channels, out_channels, stride, self.expansion, downsample)
        )
        self.in_channels = out_channels * self.expansion
        for i in range(1, blocks):
            layers.append(block(
                self.in_channels,
                out_channels,
                expansion=self.expansion
            ))
        return nn.Sequential(*layers)

    def forward(self, x: Tensor) -> Tensor:
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        # The spatial dimension of the final layer's feature map is (7, 7)
        # for the standard 224x224 ResNet input.
        #print('Dimensions of the last convolutional feature map: ', x.shape)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        # The final fc layer is intentionally skipped: this backbone returns the
        # pooled 512-d features, which AlzheimerTriMatterNet fuses itself.
        #x = self.fc(x)
        return x
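Because forward skips self.fc, this backbone emits pooled features rather than class logits. A quick illustrative check (hypothetical):

import torch

backbone = ResNet18(img_channels=3, num_layers=18, block=BasicBlock, num_classes=4)
backbone.eval()
feats = backbone(torch.randn(2, 3, 150, 150))
print(feats.shape)  # torch.Size([2, 512]): 512-d features, not 4-way logits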
app.py
ADDED
@@ -0,0 +1,93 @@
import gradio as gr
import torch
from utils import *
import os

# The checkpoint stores the full module, so torch.load returns the model itself.
model = torch.load('AlzheimerModelCPU.pth', map_location='cpu')
model.eval()

def reset():
    # Clear the image input, the prediction textbox, and the three plots.
    return None, None, None, None, None

with gr.Blocks() as demo:
    gr.HTML("""
        <h1 style="text-align: center; font-size: 50px;">
            Alzheimer Detection
        </h1>
        <p style="text-align: center;">
            Early Detection of Alzheimer's Disease: A Deep Learning Approach for Accurate Diagnosis.
        </p>
        <h3 style="text-align: left;">
            To use the demo, please follow the steps below.
        </h3>
        <ul>
            <li style="text-align: left; padding: 0px 0px 0px 30px;">
                If you want to try examples, click one of the <span style="font-weight: bold;">Examples</span> images below. Then, click
                <span style="font-weight: bold;">Submit</span>.
            </li>
            <li style="text-align: left; padding: 0px 0px 0px 30px;">
                If you don't want to try examples, upload an image and click <span style="font-weight: bold;">Submit</span>.
            </li>
            <li style="text-align: left; padding: 0px 0px 0px 30px;">
                You can adjust the <span style="font-weight: bold;">Target</span> for your desired visualization and switch the
                <span style="font-weight: bold;">Plot Type</span> between <span style="font-weight: bold;">withmask</span> and <span style="font-weight: bold;">withoutmask</span>
                to plot the original images with or without the Grad-CAM heatmap overlay.
            </li>
            <li style="text-align: left; padding: 0px 0px 0px 30px;">
                If you want to reset all components, click the <span style="font-weight: bold;">Reset All Components</span> button.
            </li>
        </ul>
        """
    )
    with gr.Row():

        with gr.Column():
            image_area = gr.Image(sources=["upload"], type="pil", scale=2, label="Upload Image")

            with gr.Row():
                chosen_plottype = gr.Radio(choices=["withmask", "withoutmask"], value="withmask", label="Plot Type", scale=1, interactive=True)
                chosen_target = gr.Slider(minimum=0, maximum=200, step=1, value=100, label="Classifier Output Target", scale=1, interactive=True)
            submit_btn = gr.Button("Submit", variant='primary', scale=1)

        with gr.Column():
            gr.HTML("""
                <h2 style="text-align: center;">
                    Output and Prediction
                </h2>
                <p style="text-align: center;">
                    Grad-CAM for 3 images: Original, White Matter, and Gray Matter.
                </p>
                """
            )
            text_area = gr.Textbox(label="Prediction Result")
            plotarea_original = gr.Plot(label="Original image")
            plotarea_white = gr.Plot(label="White Matter image")
            plotarea_gray = gr.Plot(label="Gray Matter image")

    reset_btn = gr.Button("Reset All Components", variant='stop', scale=1)

    gr.HTML("""
        <h2 style="text-align: left;">
            Examples
        </h2>
        <p style="text-align: left;">
            You can select 1 image from the examples and click "Submit".
        </p>
        """
    )
    # Joining path components separately keeps the paths portable
    # (the original used Windows-style "examples\\..." separators).
    examples = gr.Examples(examples=[os.path.join(os.getcwd(), "examples", "non.jpg"),
                                     os.path.join(os.getcwd(), "examples", "verymild.jpg"),
                                     os.path.join(os.getcwd(), "examples", "mild.jpg"),
                                     os.path.join(os.getcwd(), "examples", "moderate.jpg")],
                           inputs=image_area,
                           outputs=image_area,
                           label="Examples",
                           )

    submit_btn.click(lambda x, target, plot_type: predict_and_gradcam(x, model=model, target=target, plot_type=plot_type), inputs=[image_area, chosen_target, chosen_plottype], outputs=[text_area, plotarea_original, plotarea_white, plotarea_gray])
    reset_btn.click(reset, outputs=[image_area, text_area, plotarea_original, plotarea_white, plotarea_gray])

demo.launch()
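A rough sketch of driving the same pipeline without the Gradio UI (hypothetical code; target=100 mirrors the slider's default, and since GradCAM runs on the backbones, whose forward returns 512-d features, it selects a feature channel rather than a class index):

from PIL import Image
import torch
from utils import predict_and_gradcam

model = torch.load("AlzheimerModelCPU.pth", map_location="cpu")
model.eval()
label, fig_orig, fig_white, fig_gray = predict_and_gradcam(
    Image.open("non.jpg"), model=model, target=100, plot_type="withmask")
print(label)  # one of: Moderate / Mild / Very Mild / Non Demented
fig_orig.savefig("gradcam_original.png")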
mild.jpg
ADDED
moderate.jpg
ADDED
non.jpg
ADDED
requirements.txt
ADDED
@@ -0,0 +1,7 @@
gradio
torch
torchvision
matplotlib
opencv-python
numpy
grad-cam
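To run the Space locally, installing the dependencies with pip install -r requirements.txt and then running python app.py should suffice; the grad-cam package provides the pytorch_grad_cam module imported in utils.py, and Gradio serves the demo on its default local port (7860).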
utils.py
ADDED
@@ -0,0 +1,182 @@
import torch
from torchvision import transforms
import matplotlib.pyplot as plt
import cv2
import numpy as np
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

class split_white_and_gray():
    def __init__(self, threshold=120) -> None:
        """
        Initialize the transform with a threshold value.

        Args:
            threshold (int, optional): The intensity threshold (on the 0-255 scale)
                separating white matter from gray matter. Defaults to 120.
        """
        self.threshold = threshold

    def __call__(self, tensor):
        """
        Apply thresholding to the input tensor and return the white matter, gray matter, and the original tensor.

        Parameters:
            tensor (torch.Tensor): The input tensor to be thresholded, with values in [0, 1].

        Returns:
            torch.Tensor: The thresholded white matter.
            torch.Tensor: The thresholded gray matter.
            torch.Tensor: The original input tensor.
        """
        # Work on the 0-255 integer scale so the threshold is meaningful.
        tensor = (tensor*255).to(torch.int64)

        # Apply thresholding: pixels at or above the threshold are kept as white
        # matter, the rest as gray matter. Both are rescaled back to [0, 1].
        white_matter = torch.where(tensor >= self.threshold, tensor, 0)
        white_matter = (white_matter/255).to(torch.float64)
        gray_matter = torch.where(tensor < self.threshold, tensor, 0)
        gray_matter = (gray_matter/255).to(torch.float64)
        tensor = (tensor/255).to(torch.float64)

        return white_matter, gray_matter, tensor

def showcam_withoutmask(original_image, grayscale_cam, image_title='Original Image'):
    """This function applies the CAM mask to the original image and returns the Matplotlib Figure object.

    :param original_image: The original image tensor in PyTorch format.
    :param grayscale_cam: The CAM mask tensor in PyTorch format.

    :return: Matplotlib Figure object.
    """
    # Make sure both tensors are on the CPU.
    original_image = torch.squeeze(original_image).cpu()  # torch.Size([3, 150, 150])
    cam_mask = grayscale_cam.cpu()  # torch.Size([1, 150, 150])

    # Convert the tensors to NumPy arrays.
    original_image_np = original_image.numpy()
    cam_mask_np = cam_mask.numpy()

    # Apply the mask to the original image.
    masked_image = original_image_np * cam_mask_np

    # Normalize the masked image to [0, 1] for display.
    masked_image_norm = (masked_image - np.min(masked_image)) / (np.max(masked_image) - np.min(masked_image))

    # Create the Matplotlib figure.
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))

    # Plot the original image (stored in (C, H, W) format, so transpose to (H, W, C)).
    axes[0].imshow(original_image_np.transpose(1, 2, 0))
    axes[0].set_title(image_title)

    # Plot the grayscale CAM mask.
    axes[1].imshow(cam_mask_np[0], cmap='jet')
    axes[1].set_title('CAM Mask')

    # Plot the overlay (normalized).
    axes[2].imshow(masked_image_norm.transpose(1, 2, 0))
    axes[2].set_title('Overlay (Normalized)')

    return fig

def showcam_withmask(img_tensor: torch.Tensor,
                     mask_tensor: torch.Tensor,
                     use_rgb: bool = False,
                     colormap: int = cv2.COLORMAP_JET,
                     image_weight: float = 0.5,
                     image_title: str = 'Original Image') -> plt.Figure:
    """This function overlays the CAM mask on the image as a heatmap and returns the Figure object.
    By default, the heatmap is in BGR format.

    :param img_tensor: The base image tensor in PyTorch format.
    :param mask_tensor: The CAM mask tensor in PyTorch format.
    :param use_rgb: Whether to use an RGB or BGR heatmap; set to True if 'img_tensor' is in RGB format.
    :param colormap: The OpenCV colormap to be used.
    :param image_weight: The final result is image_weight * img + (1-image_weight) * mask.

    :return: Matplotlib Figure object.
    """
    # Convert PyTorch tensors to NumPy arrays.
    img = img_tensor.cpu().numpy().transpose(1, 2, 0)
    mask = mask_tensor.cpu().numpy()

    # Convert the mask to a single-channel uint8 image.
    mask_single_channel = np.uint8(255 * mask[0])

    heatmap = cv2.applyColorMap(mask_single_channel, colormap)

    if use_rgb:
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)

    heatmap = np.float32(heatmap) / 255

    if np.max(img) > 1:
        raise Exception("The input image should be in the range [0, 1]")

    if image_weight < 0 or image_weight > 1:
        raise Exception(f"image_weight should be in the range [0, 1]. Got: {image_weight}")

    # Blend the heatmap with the image and renormalize for display.
    cam = (1 - image_weight) * heatmap + image_weight * img
    cam = cam / np.max(cam)

    # Create the Matplotlib figure.
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    # Plot the original image.
    axes[0].imshow(img)
    axes[0].set_title(image_title)

    # Plot the CAM mask.
    axes[1].imshow(mask[0], cmap='jet')
    axes[1].set_title('CAM Mask')

    # Plot the overlay.
    axes[2].imshow(cam)
    axes[2].set_title('Overlay')

    return fig

def predict_and_gradcam(pil_image, model, target=100, plot_type='withmask'):
    # Resize, replicate the grayscale channel to 3 channels, and split into
    # white-matter / gray-matter / original tensors.
    transform = transforms.Compose([
        transforms.Resize((150, 150)),
        transforms.Grayscale(num_output_channels=3),
        transforms.ToTensor(),
        split_white_and_gray(120),
    ])
    white_matter_tensor, gray_matter_tensor, origin_tensor = transform(pil_image)
    white_matter_tensor, gray_matter_tensor, origin_tensor = white_matter_tensor.unsqueeze(0).to(torch.float32),\
                                                             gray_matter_tensor.unsqueeze(0).to(torch.float32),\
                                                             origin_tensor.unsqueeze(0).to(torch.float32)

    def calculate_gradcammask(model_grad, input_tensor):
        # Use the last block of layer4 as the target layer for Grad-CAM.
        target_layer = [model_grad.layer4[-1]]
        gradcam = GradCAM(model=model_grad, target_layers=target_layer)
        targets = [ClassifierOutputTarget(target)]
        grayscale_cam = gradcam(input_tensor=input_tensor, targets=targets, aug_smooth=True, eigen_smooth=True)
        grayscale_cam = torch.tensor(grayscale_cam)

        return grayscale_cam

    origin_model = model.resnet18_model
    white_model = model.whitematter_resnet18_model
    gray_model = model.graymatter_resnet18_model

    origin_cam = calculate_gradcammask(origin_model, origin_tensor)
    white_cam = calculate_gradcammask(white_model, white_matter_tensor)
    gray_cam = calculate_gradcammask(gray_model, gray_matter_tensor)

    class_idx = {0: 'Moderate Demented', 1: 'Mild Demented', 2: 'Very Mild Demented', 3: 'Non Demented'}
    prediction = model(white_matter_tensor, gray_matter_tensor, origin_tensor)
    predicted_class_index = torch.argmax(prediction).item()
    predicted_class_label = class_idx[predicted_class_index]

    if plot_type == 'withmask':
        return predicted_class_label, showcam_withmask(torch.squeeze(origin_tensor), origin_cam),\
               showcam_withmask(torch.squeeze(white_matter_tensor), white_cam, image_title='White Matter'),\
               showcam_withmask(torch.squeeze(gray_matter_tensor), gray_cam, image_title='Gray Matter')
    elif plot_type == 'withoutmask':
        return predicted_class_label, showcam_withoutmask(torch.squeeze(origin_tensor), origin_cam),\
               showcam_withoutmask(torch.squeeze(white_matter_tensor), white_cam, image_title='White Matter'),\
               showcam_withoutmask(torch.squeeze(gray_matter_tensor), gray_cam, image_title='Gray Matter')
    else:
        raise ValueError("plot_type must be either 'withmask' or 'withoutmask'")
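A small illustrative check of the threshold split (hypothetical; inputs are [0, 1] float tensors, as produced by ToTensor()):

import torch

splitter = split_white_and_gray(threshold=120)
img = torch.rand(3, 150, 150)
white, gray, original = splitter(img)
# Nonzero white-matter pixels sit at or above 120/255; gray-matter pixels below it.
assert torch.all((white == 0) | (white >= 120 / 255))
assert torch.all(gray < 120 / 255)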
verymild.jpg
ADDED