Jiranuwat commited on
Commit
201936b
·
verified ·
1 Parent(s): 9e7fdb8

Upload 10 files

Browse files
Files changed (10) hide show
  1. AlzheimerModelCPU.pth +3 -0
  2. AlzheimerTriMatterNet.py +24 -0
  3. Resnet18.py +100 -0
  4. app.py +93 -0
  5. mild.jpg +0 -0
  6. moderate.jpg +0 -0
  7. non.jpg +0 -0
  8. requirements.txt +7 -0
  9. utils.py +182 -0
  10. verymild.jpg +0 -0
AlzheimerModelCPU.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2241af42d3188aea27dcc36259f72e57a6ad06b555e1687c3f298e52c3128ce8
3
+ size 134441434
AlzheimerTriMatterNet.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import torchvision.models as models
3
+ from Resnet18 import *
4
+
5
+ class AlzheimerTriMatterNet(nn.Module):
6
+ def __init__(self):
7
+ super(AlzheimerTriMatterNet, self).__init__()
8
+ self.numclass = 4
9
+ self.whitematter_resnet18_model = ResNet18(img_channels=3, num_layers=18, block=BasicBlock, num_classes=4)
10
+ self.graymatter_resnet18_model = ResNet18(img_channels=3, num_layers=18, block=BasicBlock, num_classes=4)
11
+ self.resnet18_model = ResNet18(img_channels=3, num_layers=18, block=BasicBlock, num_classes=4)
12
+ self.global_classification_head = nn.Sequential(
13
+ nn.Linear(512*3,self.numclass),
14
+ nn.Softmax(dim=1),
15
+ )
16
+
17
+ def forward(self, whitematter, graymatter, original):
18
+ white_output = self.whitematter_resnet18_model(whitematter)
19
+ gray_output = self.graymatter_resnet18_model(graymatter)
20
+ origin_output = self.resnet18_model(original)
21
+ combined_tensor = torch.cat(( white_output, gray_output, origin_output), dim=1)
22
+ output = self.global_classification_head(combined_tensor)
23
+ return output
24
+
Resnet18.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Code from : https://debuggercafe.com/implementing-resnet18-in-pytorch-from-scratch/"""
2
+ import torch.nn as nn
3
+ import torch
4
+ from torchvision.ops import RoIPool
5
+ from torch import Tensor
6
+ from typing import Type
7
+
8
+ class BasicBlock(nn.Module):
9
+ def __init__(self, in_channels: int,out_channels: int,stride: int = 1,expansion: int = 1,downsample: nn.Module = None) -> None:
10
+ super(BasicBlock, self).__init__()
11
+ # Multiplicative factor for the subsequent conv2d layer's output channels.
12
+ # It is 1 for ResNet18 and ResNet34.
13
+ self.expansion = expansion
14
+ self.downsample = downsample
15
+ self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1,bias=False)
16
+ self.bn1 = nn.BatchNorm2d(out_channels)
17
+ self.relu = nn.ReLU(inplace=True)
18
+ self.conv2 = nn.Conv2d(out_channels, out_channels*self.expansion, kernel_size=3, padding=1,bias=False)
19
+ self.bn2 = nn.BatchNorm2d(out_channels*self.expansion)
20
+
21
+ def forward(self, x: Tensor) -> Tensor:
22
+ identity = x
23
+ out = self.conv1(x)
24
+ out = self.bn1(out)
25
+ out = self.relu(out)
26
+ out = self.conv2(out)
27
+ out = self.bn2(out)
28
+ if self.downsample is not None:
29
+ identity = self.downsample(x)
30
+ out += identity
31
+ out = self.relu(out)
32
+ return out
33
+
34
+ class ResNet18(nn.Module):
35
+ def __init__(self, img_channels: int,num_layers: int,block: Type[BasicBlock],num_classes: int = 1000) -> None:
36
+ super(ResNet18, self).__init__()
37
+ if num_layers == 18:
38
+ # The following `layers` list defines the number of `BasicBlock`
39
+ # to use to build the network and how many basic blocks to stack
40
+ # together.
41
+ layers = [2, 2, 2, 2]
42
+ self.expansion = 1
43
+
44
+ self.in_channels = 64
45
+ # All ResNets (18 to 152) contain a Conv2d => BN => ReLU for the first
46
+ # three layers. Here, kernel size is 7.
47
+ self.conv1 = nn.Conv2d(in_channels=img_channels,out_channels=self.in_channels,kernel_size=7, stride=2,padding=3,bias=False)
48
+ self.bn1 = nn.BatchNorm2d(self.in_channels)
49
+ self.relu = nn.ReLU(inplace=True)
50
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
51
+ self.layer1 = self._make_layer(block, 64, layers[0])
52
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
53
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
54
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
55
+ self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
56
+ self.fc = nn.Linear(512*self.expansion, num_classes)
57
+
58
+ def _make_layer(self, block: Type[BasicBlock],out_channels: int,blocks: int,stride: int = 1) -> nn.Sequential:
59
+ downsample = None
60
+ if stride != 1:
61
+ """
62
+ This should pass from `layer2` to `layer4` or
63
+ when building ResNets50 and above. Section 3.3 of the paper
64
+ Deep Residual Learning for Image Recognition
65
+ (https://arxiv.org/pdf/1512.03385v1.pdf).
66
+ """
67
+ downsample = nn.Sequential(
68
+ nn.Conv2d(self.in_channels, out_channels*self.expansion,kernel_size=1,stride=stride,bias=False),
69
+ nn.BatchNorm2d(out_channels * self.expansion),
70
+ )
71
+ layers = []
72
+ layers.append(
73
+ block(
74
+ self.in_channels, out_channels, stride, self.expansion, downsample
75
+ )
76
+ )
77
+ self.in_channels = out_channels * self.expansion
78
+ for i in range(1, blocks):
79
+ layers.append(block(
80
+ self.in_channels,
81
+ out_channels,
82
+ expansion=self.expansion
83
+ ))
84
+ return nn.Sequential(*layers)
85
+ def forward(self, x: Tensor) -> Tensor:
86
+ x = self.conv1(x)
87
+ x = self.bn1(x)
88
+ x = self.relu(x)
89
+ x = self.maxpool(x)
90
+ x = self.layer1(x)
91
+ x = self.layer2(x)
92
+ x = self.layer3(x)
93
+ x = self.layer4(x)
94
+ # The spatial dimension of the final layer's feature
95
+ # map should be (7, 7) for all ResNets.
96
+ #print('Dimensions of the last convolutional feature map: ', x.shape)
97
+ x = self.avgpool(x)
98
+ x = torch.flatten(x, 1)
99
+ #x = self.fc(x)
100
+ return x
app.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from utils import *
4
+ import os
5
+
6
+ model = torch.load('AlzheimerModelCPU.pth')
7
+ model.eval()
8
+
9
+ def reset():
10
+ return None, None, None, None, None
11
+
12
+ with gr.Blocks() as demo:
13
+ gr.HTML("""
14
+ <h1 style="text-align: center; font-size: 50px;">
15
+ Alzheimer Detection
16
+ </h1>
17
+ <p style="text-align: center;">
18
+ Early Detection of Alzheimer's Disease: A Deep Learning Approach for Accurate Diagnosis.
19
+ </p>
20
+ <h3 style="text-align: left;">
21
+ To use the demo, please follow the steps below.
22
+ </h3>
23
+ <ul>
24
+ <li style="text-align: left; padding: 0px 0px 0px 30px;">
25
+ If you want to try examples, click one of <span style="font-weight: bold;">Examples</span> images below. Then, click
26
+ <span style="font-weight: bold;">Submit</span>.
27
+ </li>
28
+ <li style="text-align: left; padding: 0px 0px 0px 30px;">
29
+ if you don't want to try examples, upload an image and click <span style="font-weight: bold;">Submit</span>.
30
+ </li>
31
+ <li style="text-align: left; padding: 0px 0px 0px 30px;">
32
+ You can adjust the <span style="font-weight: bold;">Target</span> for your desire visualization and
33
+ <span style="font-weight: bold;">Plot Type</span> between <span style="font-weight: bold;">withmask</span> and <span style="font-weight: bold;">withoutmask</span>
34
+ to plot original images with or without Grad-CAM.
35
+ </li>
36
+ <li style="text-align: left; padding: 0px 0px 0px 30px;">
37
+ If you want to reset all components, click <span style="font-weight: bold;">Reset All Components</span> button.
38
+ </li>
39
+ </ul>
40
+ """
41
+ )
42
+ with gr.Row():
43
+
44
+ with gr.Column():
45
+ image_area = gr.Image(sources=["upload"], type="pil", scale=2, label="Upload Image")
46
+
47
+ with gr.Row():
48
+ choosen_plottype =gr.Radio(choices=["withmask", "withoutmask"], value="withmask", label="Plot Type", scale=1, interactive=True)
49
+ choosen_target = gr.Slider(minimum=0, maximum=200, step=1, value=100, label="Classifier OutputTarget", scale=1, interactive=True)
50
+ submit_btn = gr.Button("Submit", variant='primary', scale=1)
51
+
52
+ with gr.Column():
53
+ gr.HTML("""
54
+ <h2 style="text-align: center;">
55
+
56
+ </h2>
57
+ <h2 style="text-align: center;">
58
+ Output and Prediction
59
+ </h2>
60
+ <p style="text-align: center;">
61
+ Grad-CAM for 3 images. Original, White Matter, Gray Matter image.
62
+ </p>
63
+ """
64
+ )
65
+ text_area = gr.Textbox(label="Prediction Result")
66
+ plotarea_original = gr.Plot(label="Original image")
67
+ plotarea_white = gr.Plot(label="White Matter image")
68
+ plotarea_gray = gr.Plot(label="Gray Matter image")
69
+
70
+ reset_btn = gr.Button("Reset All Components", variant='stop', scale=1)
71
+
72
+ gr.HTML("""
73
+ <h2 style="text-align: left;">
74
+ Examples
75
+ </h2>
76
+ <p style="text-align: left;">
77
+ You can select 1 image from the examples and click "Submit".
78
+ </p>
79
+ """
80
+ )
81
+ examples = gr.Examples(examples=[ os.path.join(os.getcwd(), "examples\\non.jpg"),
82
+ os.path.join(os.getcwd(),"examples\\verymild.jpg"),
83
+ os.path.join(os.getcwd(), "examples\\mild.jpg"),
84
+ os.path.join(os.getcwd(), "examples\\moderate.jpg")],
85
+ inputs=image_area,
86
+ outputs=image_area,
87
+ label="Examples",
88
+ )
89
+
90
+ submit_btn.click(lambda x, target, plot_type: predict_and_gradcam(x, model=model, target=target, plot_type=plot_type), inputs=[image_area, choosen_target, choosen_plottype], outputs=[text_area, plotarea_original, plotarea_white, plotarea_gray])
91
+ reset_btn.click(reset, outputs=[image_area, text_area, plotarea_original, plotarea_white, plotarea_gray])
92
+
93
+ demo.launch()
mild.jpg ADDED
moderate.jpg ADDED
non.jpg ADDED
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ gradio
2
+ torch
3
+ torchvision
4
+ matplotlib
5
+ opencv-python
6
+ numpy
7
+ grad-cam
utils.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torchvision import transforms
3
+ import matplotlib.pyplot as plt
4
+ import cv2
5
+ import numpy as np
6
+ from pytorch_grad_cam import GradCAM
7
+ from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
8
+
9
+ class split_white_and_gray():
10
+ def __init__(self,threshold=120) -> None:
11
+ """
12
+ Initialize the class with a threshold value.
13
+
14
+ Args:
15
+ threshold (int, optional): The threshold value to be set. Defaults to 120.
16
+ """
17
+ self.threshold = threshold
18
+
19
+ def __call__(self,tensor):
20
+ """
21
+ Apply thresholding to the input tensor and return the white matter, gray matter, and the original tensor.
22
+
23
+ Parameters:
24
+ tensor (torch.Tensor): The input tensor to be thresholded.
25
+
26
+ Returns:
27
+ torch.Tensor: The thresholded white matter.
28
+ torch.Tensor: The thresholded gray matter.
29
+ torch.Tensor: The original input tensor.
30
+ """
31
+ tensor = (tensor*255).to(torch.int64)
32
+
33
+ # Apply thresholding
34
+ white_matter = torch.where(tensor >= self.threshold,tensor,0)
35
+ white_matter = (white_matter/255).to(torch.float64)
36
+ gray_matter = torch.where(tensor < self.threshold,tensor,0)
37
+ gray_matter = (gray_matter/255).to(torch.float64)
38
+ tensor = (tensor/255).to(torch.float64)
39
+
40
+ return white_matter, gray_matter,tensor
41
+
42
+ def showcam_withoutmask(original_image, grayscale_cam, image_title='Original Image'):
43
+ """This function applies the CAM mask to the original image and returns the Matplotlib Figure object.
44
+
45
+ :param original_image: The original image tensor in PyTorch format.
46
+ :param grayscale_cam: The CAM mask tensor in PyTorch format.
47
+
48
+ :return: Matplotlib Figure object.
49
+ """
50
+ # Assuming you have two tensors: 'original_image' and 'cam_mask'
51
+ # Make sure both tensors are on the CPU
52
+ original_image = torch.squeeze(original_image).cpu() # torch.Size([3, 150, 150])
53
+ cam_mask = grayscale_cam.cpu() # torch.Size([1, 150, 150])
54
+
55
+ # Convert the tensors to NumPy arrays
56
+ original_image_np = original_image.numpy()
57
+ cam_mask_np = cam_mask.numpy()
58
+
59
+ # Apply the mask to the original image
60
+ masked_image = original_image_np * cam_mask_np
61
+
62
+ # Normalize the masked_image
63
+ masked_image_norm = (masked_image - np.min(masked_image)) / (np.max(masked_image) - np.min(masked_image))
64
+
65
+ # Create Matplotlib Figure
66
+ fig, axes = plt.subplots(1, 3, figsize=(18, 5))
67
+
68
+ # Plot the original image
69
+ axes[0].imshow(original_image_np.transpose(1, 2, 0)) # Assuming your original image is in (C, H, W) format
70
+ axes[0].set_title(image_title)
71
+
72
+ # Plot the CAM mask
73
+ axes[1].imshow(cam_mask_np[0], cmap='jet') # Assuming your mask is grayscale
74
+ axes[1].set_title('CAM Mask')
75
+
76
+ # Plot the overlay (normalized)
77
+ axes[2].imshow(masked_image_norm.transpose(1, 2, 0)) # Assuming your original image is in (C, H, W) format
78
+ axes[2].set_title('Overlay (Normalized)')
79
+
80
+ return fig
81
+
82
+ def showcam_withmask(img_tensor: torch.Tensor,
83
+ mask_tensor: torch.Tensor,
84
+ use_rgb: bool = False,
85
+ colormap: int = cv2.COLORMAP_JET,
86
+ image_weight: float = 0.5,
87
+ image_title: str = 'Original Image') -> plt.Figure:
88
+ """ This function overlays the CAM mask on the image as a heatmap and returns the Figure object.
89
+ By default, the heatmap is in BGR format.
90
+
91
+ :param img_tensor: The base image tensor in PyTorch format.
92
+ :param mask_tensor: The CAM mask tensor in PyTorch format.
93
+ :param use_rgb: Whether to use an RGB or BGR heatmap; set to True if 'img_tensor' is in RGB format.
94
+ :param colormap: The OpenCV colormap to be used.
95
+ :param image_weight: The final result is image_weight * img + (1-image_weight) * mask.
96
+
97
+ :return: Matplotlib Figure object.
98
+ """
99
+ # Convert PyTorch tensors to NumPy arrays
100
+ img = img_tensor.cpu().numpy().transpose(1, 2, 0)
101
+ mask = mask_tensor.cpu().numpy()
102
+
103
+ # Convert the mask to a single-channel image
104
+ mask_single_channel = np.uint8(255 * mask[0])
105
+
106
+ heatmap = cv2.applyColorMap(mask_single_channel, colormap)
107
+
108
+ if use_rgb:
109
+ heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
110
+
111
+ heatmap = np.float32(heatmap) / 255
112
+
113
+ if np.max(img) > 1:
114
+ raise Exception("The input image should be in the range [0, 1]")
115
+
116
+ if image_weight < 0 or image_weight > 1:
117
+ raise Exception(f"image_weight should be in the range [0, 1]. Got: {image_weight}")
118
+
119
+ cam = (1 - image_weight) * heatmap + image_weight * img
120
+ cam = cam / np.max(cam)
121
+
122
+ # Create Matplotlib Figure
123
+ fig, axes = plt.subplots(1, 3, figsize=(15, 5))
124
+
125
+ # Plot the original image
126
+ axes[0].imshow(img)
127
+ axes[0].set_title(image_title)
128
+
129
+ # Plot the CAM mask
130
+ axes[1].imshow(mask[0], cmap='jet')
131
+ axes[1].set_title('CAM Mask')
132
+
133
+ # Plot the overlay
134
+ axes[2].imshow(cam)
135
+ axes[2].set_title('Overlay')
136
+
137
+ return fig
138
+
139
+ def predict_and_gradcam(pil_image, model, target=100, plot_type='withmask'):
140
+ transform = transforms.Compose([
141
+ transforms.Resize((150, 150)),
142
+ transforms.Grayscale(num_output_channels=3),
143
+ transforms.ToTensor(),
144
+ split_white_and_gray(120),
145
+ ])
146
+ white_matter_tensor, gray_matter_tensor, origin_tensor = transform(pil_image)
147
+ white_matter_tensor, gray_matter_tensor, origin_tensor = white_matter_tensor.unsqueeze(0).to(torch.float32),\
148
+ gray_matter_tensor.unsqueeze(0).to(torch.float32),\
149
+ origin_tensor.unsqueeze(0).to(torch.float32)
150
+
151
+ def calculate_gradcammask(model_grad, input_tensor):
152
+ target_layer = [model_grad.layer4[-1]]
153
+ gradcam = GradCAM(model=model_grad, target_layers=target_layer)
154
+ targets = [ClassifierOutputTarget(target)]
155
+ grayscale_cam = gradcam(input_tensor=input_tensor, targets=targets, aug_smooth=True, eigen_smooth=True)
156
+ grayscale_cam = torch.tensor(grayscale_cam)
157
+
158
+ return grayscale_cam
159
+
160
+ origin_model = model.resnet18_model
161
+ white_model = model.whitematter_resnet18_model
162
+ gray_model = model.graymatter_resnet18_model
163
+
164
+ origin_cam = calculate_gradcammask(origin_model, origin_tensor)
165
+ white_cam = calculate_gradcammask(white_model, white_matter_tensor)
166
+ gray_cam = calculate_gradcammask(gray_model, gray_matter_tensor)
167
+
168
+ class_idx = {0: 'Moderate Demented', 1: 'Mild Demented', 2: 'Very Mild Demented', 3: 'Non Demented'}
169
+ prediction = model(white_matter_tensor, gray_matter_tensor, origin_tensor)
170
+ predicted_class_index = torch.argmax(prediction).item()
171
+ predicted_class_label = class_idx[predicted_class_index]
172
+
173
+ if plot_type == 'withmask':
174
+ return predicted_class_label, showcam_withmask(torch.squeeze(origin_tensor), origin_cam),\
175
+ showcam_withmask(torch.squeeze(white_matter_tensor), white_cam, image_title='White Matter'),\
176
+ showcam_withmask(torch.squeeze(gray_matter_tensor), gray_cam, image_title='Gray Matter')
177
+ elif plot_type == 'withoutmask':
178
+ return predicted_class_label, showcam_withoutmask(torch.squeeze(origin_tensor),origin_cam),\
179
+ showcam_withoutmask(torch.squeeze(white_matter_tensor),white_cam, image_title='White Matter'),\
180
+ showcam_withoutmask(torch.squeeze(gray_matter_tensor),gray_cam , image_title='Gray Matter')
181
+ else:
182
+ raise ValueError("plot_type must be either 'withmask' or 'withoutmask'")
verymild.jpg ADDED