Shilpaj committed
Commit 327f84a · 1 Parent(s): 41dbd0c

Feat: App file and visualize.py

Files changed (3)
  1. .gitignore +1 -0
  2. app.py +237 -0
  3. visualize.py +386 -0
.gitignore ADDED
@@ -0,0 +1 @@
+ .idea/
app.py ADDED
@@ -0,0 +1,237 @@
+ #!/usr/bin/env python3
+ """
+ Gradio application for a model trained on the CIFAR10 dataset
+ Author: Shilpaj Bhalerao
+ Date: Aug 06, 2023
+ """
+ # Standard Library Imports
+ import os
+ from collections import OrderedDict
+
+ # Third-Party Imports
+ import gradio as gr
+ import numpy as np
+ import torch
+ from torchvision import transforms
+ from pytorch_grad_cam import GradCAM
+ from pytorch_grad_cam.utils.image import show_cam_on_image
+
+ # Local Imports
+ from resnet import LITResNet
+ from visualize import FeatureMapVisualizer
+
+ # Paths to the example images and the trained checkpoint
+ example_directory = 'examples/'
+ model_path = 'epoch=23-step=2112.ckpt'
+
+ classes = ('plane', 'car', 'bird', 'cat', 'deer',
+            'dog', 'frog', 'horse', 'ship', 'truck')
+
+ model = LITResNet.load_from_checkpoint(model_path, map_location=torch.device('cpu'), strict=False, class_names=classes)
+ model.eval()
+
+ # Create a feature-map visualizer for the model
+ viz = FeatureMapVisualizer(model)
+
+
+ def inference(input_img,
+               transparency=0.5,
+               number_of_top_classes=3,
+               target_layer_number=4):
+     """
+     Function to run inference on the input image
+     :param input_img: Image provided by the user
+     :param transparency: Weight of the CAM overlay on the input image
+     :param number_of_top_classes: Number of top predictions to display for the input image
+     :param target_layer_number: Layer for which GradCAM is to be shown
+     """
+     # Calculate the mean of each channel of the input image
+     mean_r, mean_g, mean_b = np.mean(input_img[:, :, 0]/255.), np.mean(input_img[:, :, 1]/255.), np.mean(input_img[:, :, 2]/255.)
+
+     # Calculate the standard deviation of each channel
+     std_r, std_g, std_b = np.std(input_img[:, :, 0]/255.), np.std(input_img[:, :, 1]/255.), np.std(input_img[:, :, 2]/255.)
+
+     # Convert the image to a tensor and normalize it
+     _transform = transforms.Compose([
+         transforms.ToTensor(),
+         transforms.Normalize((mean_r, mean_g, mean_b), (std_r, std_g, std_b))
+     ])
+
+     # Save a copy of the input image
+     org_img = input_img
+
+     # Apply the transforms on the input image
+     input_img = _transform(input_img)
+
+     # Add a batch dimension to perform inference
+     input_img = input_img.unsqueeze(0)
+
+     # Get model predictions
+     with torch.no_grad():
+         outputs = model(input_img)
+         o = torch.exp(outputs)[0]
+         confidences = {classes[i]: float(o[i]) for i in range(10)}
+
+     # Select the top classes based on user input
+     sorted_confidences = sorted(confidences.items(), key=lambda val: val[1], reverse=True)
+     show_confidences = OrderedDict(sorted_confidences[:number_of_top_classes])
+
+     # Names of the blocks defined in the model
+     _layers = ['prep_layer', 'custom_block1', 'resnet_block1',
+                'custom_block2', 'custom_block3', 'resnet_block3']
+     target_layers = [getattr(model, _layers[target_layer_number - 1])[0]]
+
+     # Get the class activations from the selected layer
+     cam = GradCAM(model=model, target_layers=target_layers, use_cuda=False)
+     grayscale_cam = cam(input_tensor=input_img, targets=None)
+     grayscale_cam = grayscale_cam[0, :]
+
+     # Overlay the input image with the class activations
+     visualization = show_cam_on_image(org_img/255, grayscale_cam, use_rgb=True, image_weight=transparency)
+     return show_confidences, visualization
+
+
+ def display_misclassified_images(number: int = 1):
+     """
+     Display the misclassified images saved during training
+     :param number: Number of images to display
+     """
+     # List to store the labels of the misclassified images
+     data = []
+
+     # Get the names of all the files from the misclassified directory
+     file_names = os.listdir('misclassified/')
+
+     # Save the correct label and the misclassified label as a tuple in the `data` list
+     for file in file_names:
+         file_name, extension = file.split('.')
+         correct_label, misclassified = file_name.split('_')
+         data.append((correct_label, misclassified))
+
+     # Create paths to the images so that Gradio can access them
+     file_path = ['misclassified/' + file for file in file_names]
+
+     # Return the file paths along with the correct and misclassified labels
+     return file_path[:number], data[:number]
+
+
+ def feature_maps(input_img, kernel_number=32):
+     """
+     Function to return feature maps for the selected image
+     :param input_img: User input image
+     :param kernel_number: Kernel number in each of the 6 layers
+     """
+     # Calculate the mean of each channel of the input image
+     mean_r, mean_g, mean_b = np.mean(input_img[:, :, 0]/255.), np.mean(input_img[:, :, 1]/255.), np.mean(input_img[:, :, 2]/255.)
+
+     # Calculate the standard deviation of each channel
+     std_r, std_g, std_b = np.std(input_img[:, :, 0]/255.), np.std(input_img[:, :, 1]/255.), np.std(input_img[:, :, 2]/255.)
+
+     # Convert the image to a tensor and normalize it
+     _transform = transforms.Compose([
+         transforms.ToTensor(),
+         transforms.Normalize((mean_r, mean_g, mean_b), (std_r, std_g, std_b))
+     ])
+
+     # Apply the transforms on the input image
+     input_img = _transform(input_img)
+
+     # Visualize the feature maps for the selected kernel number
+     fig = viz.visualize_feature_map_of_kernel(image=input_img, kernel_number=kernel_number)
+     return fig
+
+
+ def get_kernels(layer_number):
+     """
+     Function to get the kernels from a layer
+     :param layer_number: Number of the layer from which kernels are to be visualized
+     """
+     # Visualize the kernels from the layer
+     fig = viz.visualize_kernels_from_layer(layer_number=layer_number)
+     return fig
+
+
+ if __name__ == '__main__':
+     with gr.Blocks() as demo:
+         gr.Markdown(
+             """
+             # ResNet18 Model trained on the CIFAR10 dataset
+             - A model architecture by [David C](https://github.com/davidcpage) trained on CIFAR10 for 24 epochs to achieve 90%+ accuracy
+             - The One Cycle Policy is used during training to speed up the training process
+             - The model works for the following classes: `plane`, `car`, `bird`, `cat`, `deer`, `dog`, `frog`, `horse`, `ship`, `truck`
+
+             ### A simple Gradio interface
+             - To see what exactly the model is looking at, using GradCAM results
+             - To display the misclassified images from the 10% test split of the CIFAR10 dataset
+             - To visualize the feature maps from the first layer of each of the six convolutional blocks
+             - To visualize the kernels from the first layer of each of the six convolutional blocks
+             """
+         )
+
+         # #############################################################################
+         # ################################ GradCam Tab ################################
+         # #############################################################################
+         with gr.Tab("GradCam"):
+             with gr.Row():
+                 img_input = [gr.Image(shape=(32, 32), label="Input Image")]
+                 gradcam_outputs = [gr.Label(),
+                                    gr.Image(shape=(32, 32), label="Output").style(width=128, height=128)]
+
+             with gr.Row():
+                 gradcam_inputs = [gr.Slider(0, 1, value=0.5,
+                                             label="What percentage overlap of the class activations do you want on the input image?"),
+                                   gr.Slider(1, 10, value=3, step=1, label="How many top predictions do you want to see?"),
+                                   gr.Slider(1, 6, value=4, step=1,
+                                             label="Of the model's 6 layers, which layer's class activations do you want to see?")]
+
+             gradcam_button = gr.Button("Submit")
+             gradcam_button.click(inference, inputs=img_input + gradcam_inputs, outputs=gradcam_outputs)
+
+             gr.Markdown("## Examples")
+             gr.Examples([example_directory + 'dog.jpg', example_directory + 'cat.jpg', example_directory + 'frog.jpg', example_directory + 'bird.jpg', example_directory + 'shark-plane.jpg',
+                          example_directory + 'car.jpg', example_directory + 'truck.jpg', example_directory + 'horse.jpg', example_directory + 'plane.jpg', example_directory + 'ship.png'],
+                         inputs=img_input, fn=inference)
+
+         # ###########################################################################################
+         # ################################ Misclassified Images Tab #################################
+         # ###########################################################################################
+         with gr.Tab("Misclassified Images"):
+             with gr.Row():
+                 mis_inputs = [gr.Slider(1, 10, value=1, step=1,
+                                         label="Select the number of misclassified images you want to see")]
+                 mis_outputs = [
+                     gr.Gallery(label="Misclassified Images", show_label=False, elem_id="gallery").style(columns=[2],
+                                                                                                         rows=[2],
+                                                                                                         object_fit="contain",
+                                                                                                         height="auto"),
+                     gr.Dataframe(headers=["Correct Label", "Misclassified Label"], type="array", datatype="str",
+                                  row_count=10, col_count=2)]
+             mis_button = gr.Button("Display Misclassified Images")
+             mis_button.click(display_misclassified_images, inputs=mis_inputs, outputs=mis_outputs)
+
+         # ################################################################################################
+         # ################################ Feature Maps Visualization Tab ################################
+         # ################################################################################################
+         with gr.Tab("Feature Map Visualization"):
+             with gr.Column():
+                 feature_map_input = [gr.Image(shape=(32, 32), label="Feature Map Input Image"),
+                                      gr.Slider(1, 32, value=16, step=1,
+                                                label="Select a kernel number for which the feature maps from all 6 layers are to be shown")]
+                 feature_map_plot = gr.Plot().style()
+                 feature_map_button = gr.Button("Visualize FeatureMaps")
+                 feature_map_button.click(feature_maps, inputs=feature_map_input, outputs=feature_map_plot)
+
+         # ##########################################################################################
+         # ################################ Kernel Visualization Tab ################################
+         # ##########################################################################################
+         with gr.Tab("Kernel Visualization"):
+             with gr.Column():
+                 kernel_input = [
+                     gr.Slider(1, 4, value=2, step=1, label="Select a layer number from which the kernels are to be shown")]
+                 kernel_plot = gr.Plot().style()
+                 kernel_button = gr.Button("Visualize Kernels")
+
+                 kernel_button.click(get_kernels, inputs=kernel_input, outputs=kernel_plot)
+
+     gr.close_all()
+     demo.launch(debug=True, share=True)
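
The handlers above can be smoke-tested without launching the UI: the model is loaded at import time, while the Blocks interface only starts under `if __name__ == '__main__'`. A minimal sketch, assuming the checkpoint, the `examples/` images, and the `misclassified/` directory referenced above are all present in the working directory; Pillow is used here only to read the image into the HxWx3 uint8 array Gradio would normally hand to the handler:

import numpy as np
from PIL import Image

from app import inference, display_misclassified_images

# Read an example image in the format gr.Image(shape=(32, 32)) delivers
img = np.array(Image.open('examples/dog.jpg').convert('RGB').resize((32, 32)))

# Top-3 class confidences plus the GradCAM overlay from the 4th block
confidences, overlay = inference(img, transparency=0.5,
                                 number_of_top_classes=3,
                                 target_layer_number=4)
print(confidences)    # OrderedDict of the three most likely classes
print(overlay.shape)  # (32, 32, 3) RGB overlay image

# Paths and (correct, predicted) label pairs for two misclassified samples
paths, labels = display_misclassified_images(number=2)
print(paths, labels)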
visualize.py ADDED
@@ -0,0 +1,386 @@
+ #!/usr/bin/env python3
+ """
+ Functions used for visualization of data and results
+ Author: Shilpaj Bhalerao
+ Date: Jun 21, 2023
+ """
+ # Standard Library Imports
+ import math
+ from dataclasses import dataclass
+
+ # Third-Party Imports
+ import numpy as np
+ import matplotlib.pyplot as plt
+ import pandas as pd
+ import seaborn as sn
+ import torch
+ import torch.nn as nn
+ from torchvision import transforms
+
+
+ # ---------------------------- DATA SAMPLES ----------------------------
+ def display_mnist_data_samples(dataset: 'DataLoader object', number_of_samples: int) -> None:
+     """
+     Function to display samples from the dataloader
+     :param dataset: Train or Test dataset transformed to Tensor
+     :param number_of_samples: Number of samples to be displayed
+     """
+     # Get a batch from the dataset
+     batch_data = []
+     batch_label = []
+     for count, item in enumerate(dataset):
+         if count >= number_of_samples:
+             break
+         batch_data.append(item[0])
+         batch_label.append(item[1])
+
+     # Set up the grid for the samples
+     fig = plt.figure()
+     x_count = 5
+     y_count = 1 if number_of_samples <= 5 else math.floor(number_of_samples / x_count)
+
+     # Plot the samples from the batch
+     for i in range(number_of_samples):
+         plt.subplot(y_count, x_count, i + 1)
+         plt.tight_layout()
+         plt.imshow(batch_data[i].squeeze(), cmap='gray')
+         plt.title(batch_label[i])
+         plt.xticks([])
+         plt.yticks([])
+
+
+ def display_cifar_data_samples(data_set, number_of_samples: int, classes: list):
+     """
+     Function to display samples from the data_set
+     :param data_set: Train or Test data_set transformed to Tensor
+     :param number_of_samples: Number of samples to be displayed
+     :param classes: Names of the classes to be displayed
+     """
+     # Get a batch from the data_set
+     batch_data = []
+     batch_label = []
+     for count, item in enumerate(data_set):
+         if count >= number_of_samples:
+             break
+         batch_data.append(item[0])
+         batch_label.append(item[1])
+     batch_data = torch.stack(batch_data, dim=0).numpy()
+
+     # Plot the samples from the batch
+     fig = plt.figure()
+     x_count = 5
+     y_count = 1 if number_of_samples <= 5 else math.floor(number_of_samples / x_count)
+
+     for i in range(number_of_samples):
+         plt.subplot(y_count, x_count, i + 1)
+         plt.tight_layout()
+         plt.imshow(np.transpose(batch_data[i].squeeze(), (1, 2, 0)))
+         plt.title(classes[batch_label[i]])
+         plt.xticks([])
+         plt.yticks([])
+
+
+ # ---------------------------- MISCLASSIFIED DATA ----------------------------
+ def display_cifar_misclassified_data(data: list,
+                                      classes: list[str],
+                                      inv_normalize: transforms.Normalize,
+                                      number_of_samples: int = 10):
+     """
+     Function to plot images with labels
+     :param data: List[Tuple(image, correct label, predicted label)]
+     :param classes: Names of the classes in the dataset
+     :param inv_normalize: Transform that inverts the dataset normalization (mean and standard deviation)
+     :param number_of_samples: Number of images to plot
+     """
+     fig = plt.figure(figsize=(10, 10))
+
+     x_count = 5
+     y_count = 1 if number_of_samples <= 5 else math.floor(number_of_samples / x_count)
+
+     for i in range(number_of_samples):
+         plt.subplot(y_count, x_count, i + 1)
+         img = data[i][0].squeeze().to('cpu')
+         img = inv_normalize(img)
+         plt.imshow(np.transpose(img, (1, 2, 0)))
+         plt.title(r"Correct: " + classes[data[i][1].item()] + '\n' + 'Output: ' + classes[data[i][2].item()])
+         plt.xticks([])
+         plt.yticks([])
+
+
+ def display_mnist_misclassified_data(data: list,
+                                      number_of_samples: int = 10):
+     """
+     Function to plot images with labels
+     :param data: List[Tuple(image, correct label, predicted label)]
+     :param number_of_samples: Number of images to plot
+     """
+     fig = plt.figure(figsize=(8, 5))
+
+     x_count = 5
+     y_count = 1 if number_of_samples <= 5 else math.floor(number_of_samples / x_count)
+
+     for i in range(number_of_samples):
+         plt.subplot(y_count, x_count, i + 1)
+         img = data[i][0].squeeze(0).to('cpu')
+         plt.imshow(np.transpose(img, (1, 2, 0)), cmap='gray')
+         plt.title(r"Correct: " + str(data[i][1].item()) + '\n' + 'Output: ' + str(data[i][2].item()))
+         plt.xticks([])
+         plt.yticks([])
+
+
+ # ---------------------------- AUGMENTATION SAMPLES ----------------------------
+ def visualize_cifar_augmentation(data_set, data_transforms):
+     """
+     Function to visualize the augmented data
+     :param data_set: Dataset without transformations
+     :param data_transforms: Dictionary of transforms
+     """
+     sample, label = data_set[6]
+     total_augmentations = len(data_transforms)
+
+     fig = plt.figure(figsize=(10, 5))
+     for count, (key, trans) in enumerate(data_transforms.items()):
+         # Skip the final transform in the dictionary
+         if count == total_augmentations - 1:
+             break
+         plt.subplot(math.ceil(total_augmentations / 5), 5, count + 1)
+         augmented = trans(image=sample)['image']
+         plt.imshow(augmented)
+         plt.title(key)
+         plt.xticks([])
+         plt.yticks([])
+
+
+ def visualize_mnist_augmentation(data_set, data_transforms):
+     """
+     Function to visualize the augmented data
+     :param data_set: Dataset to visualize the augmentations on
+     :param data_transforms: Dictionary of transforms
+     """
+     sample, label = data_set[6]
+     total_augmentations = len(data_transforms)
+
+     fig = plt.figure(figsize=(10, 5))
+     for count, (key, trans) in enumerate(data_transforms.items()):
+         # Skip the final transform in the dictionary
+         if count == total_augmentations - 1:
+             break
+         plt.subplot(math.ceil(total_augmentations / 5), 5, count + 1)
+         img = trans(sample).to('cpu')
+         plt.imshow(np.transpose(img, (1, 2, 0)), cmap='gray')
+         plt.title(key)
+         plt.xticks([])
+         plt.yticks([])
+
+
+ # ---------------------------- LOSS AND ACCURACIES ----------------------------
+ def display_loss_and_accuracies(train_losses: list,
+                                 train_acc: list,
+                                 test_losses: list,
+                                 test_acc: list,
+                                 plot_size: tuple = (10, 10)) -> None:
+     """
+     Function to display training and test information (losses and accuracies)
+     :param train_losses: List containing the training loss of each epoch
+     :param train_acc: List containing the training accuracy of each epoch
+     :param test_losses: List containing the test loss of each epoch
+     :param test_acc: List containing the test accuracy of each epoch
+     :param plot_size: Size of the plot
+     """
+     # Create a 2x2 grid of plots of the given size
+     fig, axs = plt.subplots(2, 2, figsize=plot_size)
+
+     # Plot the training loss and accuracy for each epoch
+     axs[0, 0].plot(train_losses)
+     axs[0, 0].set_title("Training Loss")
+     axs[1, 0].plot(train_acc)
+     axs[1, 0].set_title("Training Accuracy")
+
+     # Plot the test loss and accuracy for each epoch
+     axs[0, 1].plot(test_losses)
+     axs[0, 1].set_title("Test Loss")
+     axs[1, 1].plot(test_acc)
+     axs[1, 1].set_title("Test Accuracy")
+
+
+ # ---------------------------- Feature Maps and Kernels ----------------------------
+ @dataclass
+ class ConvLayerInfo:
+     """
+     Data class to store a Conv layer's information
+     """
+     layer_number: int
+     weights: torch.nn.parameter.Parameter
+     layer_info: torch.nn.modules.conv.Conv2d
+
+
+ class FeatureMapVisualizer:
+     """
+     Class to visualize feature maps of the layers
+     """
+
+     def __init__(self, model):
+         """
+         Constructor
+         :param model: Model architecture
+         """
+         self.conv_layers = []
+         self.outputs = []
+         self.layerwise_kernels = None
+
+         # Dissect the model and collect every Conv2d layer found inside a Sequential block
+         counter = 0
+         model_children = model.children()
+         for children in model_children:
+             if type(children) == nn.Sequential:
+                 for child in children:
+                     if type(child) == nn.Conv2d:
+                         counter += 1
+                         self.conv_layers.append(ConvLayerInfo(layer_number=counter,
+                                                               weights=child.weight,
+                                                               layer_info=child)
+                                                 )
+
+     def get_model_weights(self):
+         """
+         Method to get the model weights
+         """
+         model_weights = [layer.weights for layer in self.conv_layers]
+         return model_weights
+
+     def get_conv_layers(self):
+         """
+         Get the convolution layers
+         """
+         conv_layers = [layer.layer_info for layer in self.conv_layers]
+         return conv_layers
+
+     def get_total_conv_layers(self) -> int:
+         """
+         Get the total number of convolution layers
+         """
+         out = self.get_conv_layers()
+         return len(out)
+
+     def feature_maps_of_all_kernels(self, image: torch.Tensor) -> dict:
+         """
+         Get feature maps from all the kernels of all the layers
+         :param image: Image to be passed through the network
+         """
+         image = image.unsqueeze(0)
+         image = image.to('cpu')
+
+         outputs = {}
+
+         # Pass the image through the conv layers one after another
+         layers = self.get_conv_layers()
+         for index, layer in enumerate(layers):
+             image = layer(image)
+             outputs[str(layer)] = image
+         self.outputs = outputs
+         return outputs
+
+     def visualize_feature_map_of_kernel(self, image: torch.Tensor, kernel_number: int):
+         """
+         Function to visualize the feature map of a given kernel number from each layer
+         :param image: Image passed through the network
+         :param kernel_number: Kernel index in each layer (should be less than or equal to the minimum number of kernels across the network)
+         """
+         # List to store the processed feature maps
+         processed = []
+
+         # Get feature maps from all kernels of all the conv layers
+         outputs = self.feature_maps_of_all_kernels(image)
+
+         # Extract the n-th kernel's output from each layer and convert it to grayscale
+         for feature_map in outputs.values():
+             try:
+                 feature_map = feature_map[0][kernel_number]
+             except IndexError:
+                 print("Kernel number should be less than the minimum number of channels in the network")
+                 break
+             gray_scale = feature_map / feature_map.shape[0]
+             processed.append(gray_scale.data.numpy())
+
+         # Plot the feature maps with the layer and kernel number
+         x_range = len(outputs) // 5 + 4
+         fig = plt.figure(figsize=(10, 10))
+         for i in range(len(processed)):
+             a = fig.add_subplot(x_range, 5, i + 1)
+             imgplot = plt.imshow(processed[i])
+             a.axis("off")
+             title = f"{list(outputs.keys())[i].split('(')[0]}_l{i + 1}_k{kernel_number}"
+             a.set_title(title, fontsize=10)
+         return fig
+
+     def get_max_kernel_number(self):
+         """
+         Function to get the maximum number of kernels in a layer of the network
+         """
+         layers = self.get_conv_layers()
+         channels = [layer.out_channels for layer in layers]
+         self.layerwise_kernels = channels
+         return max(channels)
+
+     def visualize_kernels_from_layer(self, layer_number: int):
+         """
+         Visualize the kernels from a layer
+         :param layer_number: Number of the layer from which kernels are to be visualized
+         """
+         # Get the kernel count for each layer
+         self.get_max_kernel_number()
+
+         # Zero indexing
+         layer_number = layer_number - 1
+         _kernels = self.layerwise_kernels[layer_number]
+
+         grid = math.ceil(math.sqrt(_kernels))
+
+         fig = plt.figure(figsize=(5, 4))
+         model_weights = self.get_model_weights()
+         _layer_weights = model_weights[layer_number].cpu()
+         for i, kernel in enumerate(_layer_weights):
+             plt.subplot(grid, grid, i + 1)
+             plt.imshow(kernel[0, :, :].detach(), cmap='gray')
+             plt.axis('off')
+         return fig
+
+
+ # ---------------------------- Confusion Matrix ----------------------------
+ def visualize_confusion_matrix(classes: list[str], device: str, model: 'DL Model',
+                                test_loader: torch.utils.data.DataLoader):
+     """
+     Function to generate and visualize a confusion matrix
+     :param classes: List of class names
+     :param device: cuda/cpu
+     :param model: Model architecture
+     :param test_loader: DataLoader for the test set
+     """
+     nb_classes = len(classes)
+     cm = torch.zeros(nb_classes, nb_classes)
+
+     # Accumulate prediction counts over the whole test set
+     model = model.to(device)
+     model.eval()
+     with torch.no_grad():
+         for inputs, labels in test_loader:
+             inputs = inputs.to(device)
+             labels = labels.to(device)
+
+             preds = model(inputs)
+             preds = preds.argmax(dim=1)
+
+             for t, p in zip(labels.view(-1), preds.view(-1)):
+                 cm[t, p] = cm[t, p] + 1
+
+     # Build the confusion matrix from the accumulated counts, normalizing each row
+     cf_matrix = cm.numpy()
+     df_cm = pd.DataFrame(cf_matrix / np.sum(cf_matrix, axis=1)[:, None],
+                          index=list(classes),
+                          columns=list(classes))
+     plt.figure(figsize=(12, 7))
+     sn.heatmap(df_cm, annot=True)
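
FeatureMapVisualizer can also be used outside the Gradio app. A short sketch, assuming the same checkpoint and example images as app.py (the class_names keyword mirrors the load_from_checkpoint call there); any model whose Conv2d layers sit inside nn.Sequential blocks is dissected the same way:

import torch
from PIL import Image
from torchvision import transforms

from resnet import LITResNet
from visualize import FeatureMapVisualizer

classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')
model = LITResNet.load_from_checkpoint('epoch=23-step=2112.ckpt',
                                       map_location=torch.device('cpu'),
                                       strict=False, class_names=classes)
model.eval()

viz = FeatureMapVisualizer(model)
print(viz.get_total_conv_layers())  # number of Conv2d layers found
print(viz.get_max_kernel_number())  # largest out_channels across those layers

# Feature maps of kernel 8 from every conv layer, for one example image
img = transforms.ToTensor()(Image.open('examples/cat.jpg').convert('RGB'))
viz.visualize_feature_map_of_kernel(image=img, kernel_number=8).savefig('feature_maps_k8.png')

# Kernels of the 2nd conv layer
viz.visualize_kernels_from_layer(layer_number=2).savefig('kernels_l2.png')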