nviraj committed on
Commit ebb41db
1 Parent(s): 4a660fa

Added App files

app.py CHANGED
@@ -0,0 +1,195 @@
+ # Outline
+ # Import packages
+ # Import modules
+ # Constants
+ # Load model
+ # Function to process user-uploaded images/examples
+ # Inference function
+ # Gradio examples
+ # Gradio app
+
+ # Import packages required for the app
+ import gradio as gr
+
+ # Import custom modules
+ import modules.config as config
+ import numpy as np
+ import torch
+
+ # import torchvision
+ from modules.custom_resnet import CustomResNet
+ from modules.visualize import plot_gradcam_images, plot_misclassified_images
+ from pytorch_grad_cam import GradCAM
+ from pytorch_grad_cam.utils.image import show_cam_on_image
+ from torchvision import transforms
+
+ # Load and initialize the model
+ model = CustomResNet()
+
+ # Define device
+ cpu = torch.device("cpu")
+
+ # Using the checkpoint path present in config, load the trained model
+ model.load_state_dict(torch.load(config.MODEL_PATH, map_location=cpu), strict=False)
+ # Send model to CPU
+ model.to(cpu)
+ # Put the model in evaluation mode
+ model.eval()
+ print(f"Model Device: {next(model.parameters()).device}")
+
+
+ # Load the misclassified images data
+ misclassified_image_data = torch.load(config.MISCLASSIFIED_PATH, map_location=cpu)
+
+ # Class names
+ classes = list(config.CIFAR_CLASSES)
+ # Allowed model layer names
+ model_layer_names = ["prep", "layer1_x", "layer1_r1", "layer2", "layer3_x", "layer3_r2"]
+
+
+ def get_target_layer(layer_name):
+     """Get the target layer for visualization"""
+     if layer_name == "prep":
+         return [model.prep[-1]]
+     elif layer_name == "layer1_x":
+         return [model.layer1_x[-1]]
+     elif layer_name == "layer1_r1":
+         return [model.layer1_r1[-1]]
+     elif layer_name == "layer2":
+         return [model.layer2[-1]]
+     elif layer_name == "layer3_x":
+         return [model.layer3_x[-1]]
+     elif layer_name == "layer3_r2":
+         return [model.layer3_r2[-1]]
+     else:
+         return None
+
+
+ def generate_prediction(input_image, num_classes=3, show_gradcam=True, transparency=0.6, layer_name="layer3_x"):
+     """Given an input image, generate the prediction, confidences and visualization"""
+     mean = list(config.CIFAR_MEAN)
+     std = list(config.CIFAR_STD)
+     transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])
+
+     original_img = input_image
+     input_image = transform(input_image).unsqueeze(0).to(cpu)
+     print(f"Input Device: {input_image.device}")
+
+     # Plain inference does not need gradients
+     with torch.no_grad():
+         outputs = model(input_image).to(cpu)
+         print(f"Output Device: {outputs.device}")
+         # Model returns log-probabilities; exponentiate to recover probabilities
+         o = torch.exp(outputs).to(cpu)
+         print(f"Output Exp Device: {o.device}")
+
+     o_np = np.squeeze(np.asarray(o.numpy()))
+     # Get indexes of probabilities in descending order
+     sorted_indexes = np.argsort(o_np)[::-1]
+     # Class with the highest probability
+     final_class = classes[o_np.argmax()]
+
+     confidences = {}
+     for cnt in range(int(num_classes)):
+         # Record the confidence of each of the top classes
+         confidences[classes[sorted_indexes[cnt]]] = float(o_np[sorted_indexes[cnt]])
+
+     # Show GradCAM (it needs gradients, so it runs outside the no_grad block)
+     if show_gradcam:
+         # Get the target layer
+         target_layers = get_target_layer(layer_name)
+         cam = GradCAM(model=model, target_layers=target_layers, use_cuda=False)
+         grayscale_cam = cam(input_tensor=input_image, targets=None)
+         grayscale_cam = grayscale_cam[0, :]
+         visualization = show_cam_on_image(original_img / 255, grayscale_cam, use_rgb=True, image_weight=transparency)
+     else:
+         visualization = original_img
+
+     return final_class, confidences, visualization
+
+
+ def app_interface(
+     input_image,
+     num_classes,
+     show_gradcam,
+     layer_name,
+     transparency,
+     show_misclassified,
+     num_misclassified,
+     show_gradcam_misclassified,
+     num_gradcam_misclassified,
+ ):
+     """Function which provides the Gradio interface"""
+
+     # Get the prediction for the input image along with confidences and visualization
+     final_class, confidences, visualization = generate_prediction(
+         input_image, num_classes, show_gradcam, transparency, layer_name
+     )
+
+     if show_misclassified:
+         misclassified_fig, misclassified_axs = plot_misclassified_images(
+             data=misclassified_image_data, class_label=classes, num_images=num_misclassified
+         )
+     else:
+         misclassified_fig = None
+
+     if show_gradcam_misclassified:
+         gradcam_fig, gradcam_axs = plot_gradcam_images(
+             model=model,
+             data=misclassified_image_data,
+             class_label=classes,
+             # Use the penultimate block of layer 3 as the default GradCAM target layer
+             # Chosen using the model summary so that feature-map dimensions are > 7x7
+             target_layers=get_target_layer(layer_name),
+             targets=None,
+             num_images=num_gradcam_misclassified,
+             image_weight=transparency,
+         )
+     else:
+         gradcam_fig = None
+
+     # # Delete unused axes
+     # del misclassified_axs
+     # del gradcam_axs
+
+     return final_class, confidences, visualization, misclassified_fig, gradcam_fig
+
+
+ TITLE = "CIFAR10 Image classification using a Custom ResNet Model"
+ DESCRIPTION = "Gradio App to infer using a Custom ResNet model and get GradCAM results"
+ examples = [
+     ["assets/images/airplane.jpg", 3, True, "layer3_x", 0.6, True, 5, True, 5],
+     ["assets/images/bird.jpeg", 4, True, "layer3_x", 0.7, True, 10, True, 20],
+     ["assets/images/car.jpg", 5, True, "layer3_x", 0.5, True, 15, True, 5],
+     ["assets/images/cat.jpeg", 6, True, "layer3_x", 0.65, True, 20, True, 10],
+     ["assets/images/deer.jpg", 7, False, "layer2", 0.75, True, 5, True, 5],
+     ["assets/images/dog.jpg", 8, True, "layer2", 0.55, True, 10, True, 5],
+     ["assets/images/frog.jpeg", 9, True, "layer2", 0.8, True, 15, True, 15],
+     ["assets/images/horse.jpg", 10, False, "layer1_r1", 0.85, True, 20, True, 5],
+     ["assets/images/ship.jpg", 3, True, "layer1_r1", 0.4, True, 5, True, 15],
+     ["assets/images/truck.jpg", 4, True, "layer1_r1", 0.3, True, 5, True, 10],
+ ]
+ inference_app = gr.Interface(
+     app_interface,
+     inputs=[
+         # This accepts the image after resizing it to 32x32, which is what the model expects
+         gr.Image(shape=(32, 32)),
+         gr.Number(value=3, maximum=10, minimum=1, step=1.0, precision=0, label="#Classes to show"),
+         gr.Checkbox(True, label="Show GradCAM Image"),
+         gr.Dropdown(model_layer_names, value="layer3_x", label="Visualization Layer from Model"),
+         # How strongly the CAM heatmap is overlaid on the original image
+         gr.Slider(0, 1, 0.6, label="Image Overlay Factor"),
+         gr.Checkbox(True, label="Show Misclassified Images?"),
+         gr.Slider(value=10, maximum=25, minimum=5, step=5.0, label="#Misclassified images to show"),
+         gr.Checkbox(True, label="Visualize GradCAM for Misclassified images?"),
+         gr.Slider(value=10, maximum=25, minimum=5, step=5.0, label="#GradCAM images to show"),
+     ],
+     outputs=[
+         gr.Textbox(label="Top Class", container=True),
+         gr.Label(label="Confidences", container=True),
+         gr.Image(shape=(32, 32), label="GradCAM / Input Image", container=True).style(width=256, height=256),
+         gr.Plot(label="Misclassified images", container=True),
+         gr.Plot(label="GradCAM of Misclassified images"),
+     ],
+     title=TITLE,
+     description=DESCRIPTION,
+     examples=examples,
+ )
+ inference_app.launch()
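For a quick smoke test of generate_prediction outside the Gradio UI, a sketch along these lines can be run in a session where app.py's definitions are loaded (with the inference_app.launch() call skipped); it assumes one of the bundled example images is present:

    import numpy as np
    from PIL import Image

    # Load a bundled example image and resize it to the 32x32 input the model expects
    img = np.asarray(Image.open("assets/images/dog.jpg").convert("RGB").resize((32, 32)))
    top_class, confidences, overlay = generate_prediction(img, num_classes=3, show_gradcam=True)
    print(top_class, confidences)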
assets/images/airplane.jpg ADDED
assets/images/bird.jpeg ADDED
assets/images/car.jpg ADDED
assets/images/cat.jpeg ADDED
assets/images/deer.jpg ADDED
assets/images/dog.jpg ADDED
assets/images/frog.jpeg ADDED
assets/images/horse.jpg ADDED
assets/images/ship.jpg ADDED
assets/images/truck.jpg ADDED
assets/model/CustomResNet.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5535c4904e58078bfd7ea91c78d0536a318006bf61e24fec575da0bd5656e791
+ size 26326547
assets/model/Misclassified_Data.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:23e05b73fa387d4f3037d4a2c372615aac531f79361f029a9d2fae125ec575af
+ size 447578
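Both .pt files are committed as Git LFS pointer files: the repository stores only the version/oid/size stanza shown above, while the actual binaries live in LFS storage and are fetched on checkout.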
gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
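These patterns are the stock Hugging Face LFS rules; the *.pt and *.pth entries are what route the two checkpoint files above through LFS. New patterns can be added with git lfs track "<pattern>", which appends a matching line to this file.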
modules/__pycache__/config.cpython-311.pyc ADDED
Binary file (966 Bytes).
modules/__pycache__/custom_resnet.cpython-311.pyc ADDED
Binary file (15.2 kB).
modules/__pycache__/visualize.cpython-311.pyc ADDED
Binary file (7.84 kB).
modules/config.py ADDED
@@ -0,0 +1,38 @@
+ # Alert: Change these when running in production
+
+ # Constants naming convention: all caps separated by underscores
+ # https://realpython.com/python-constants/
+
+ # Where do we store the data?
+ MISCLASSIFIED_PATH = "./assets/model/Misclassified_Data.pt"
+ MODEL_PATH = "./assets/model/CustomResNet.pt"
+
+ # Set seed value for reproducibility
+ SEED = 53
+
+ # What are the mean and standard deviation of the dataset?
+ CIFAR_MEAN = (0.4915, 0.4823, 0.4468)
+ CIFAR_STD = (0.2470, 0.2435, 0.2616)
+
+ # What are the classes in CIFAR10?
+ # Create class labels and convert to a tuple
+ CIFAR_CLASSES = tuple(
+     c.capitalize()
+     for c in [
+         "plane",
+         "car",
+         "bird",
+         "cat",
+         "deer",
+         "dog",
+         "frog",
+         "horse",
+         "ship",
+         "truck",
+     ]
+ )
+
+ # Needed to load the model module
+ # What are the start LR and weight decay you'd prefer?
+ PREFERRED_START_LR = 5e-3
+ PREFERRED_WEIGHT_DECAY = 1e-5
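As a sanity check of the paths defined here, a minimal sketch (assuming the checkpoint under assets/ is present) loads the state dict into the model with strict=False, mirroring what app.py does, and inspects which keys did not match:

    import torch
    import modules.config as config
    from modules.custom_resnet import CustomResNet

    model = CustomResNet()
    state = torch.load(config.MODEL_PATH, map_location=torch.device("cpu"))
    # strict=False returns the keys that failed to match, useful for spotting checkpoint drift
    missing, unexpected = model.load_state_dict(state, strict=False)
    print("missing:", missing, "unexpected:", unexpected)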
modules/custom_resnet.py ADDED
@@ -0,0 +1,456 @@
+ """Module to define the model."""
+
+ # Resources
+ # https://lightning.ai/docs/pytorch/stable/starter/introduction.html
+ # https://lightning.ai/docs/pytorch/stable/starter/converting.html
+ # https://lightning.ai/docs/pytorch/stable/notebooks/lightning_examples/cifar10-baseline.html
+
+ import modules.config as config
+ import pytorch_lightning as pl
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torch.optim as optim
+ import torchinfo
+ from torch.optim.lr_scheduler import OneCycleLR
+ from torch_lr_finder import LRFinder
+ from torchmetrics import Accuracy
+
+ # What are the start LR and weight decay you'd prefer?
+ PREFERRED_START_LR = config.PREFERRED_START_LR
+ PREFERRED_WEIGHT_DECAY = config.PREFERRED_WEIGHT_DECAY
+
+
+ def detailed_model_summary(model, input_size):
+     """Print a detailed model summary."""
+
+     # https://github.com/TylerYep/torchinfo
+     torchinfo.summary(
+         model,
+         input_size=input_size,
+         batch_dim=0,
+         col_names=(
+             "input_size",
+             "kernel_size",
+             "output_size",
+             "num_params",
+             "trainable",
+         ),
+         verbose=1,
+         col_width=16,
+     )
+
+
+ ############# Assignment 12 Model #############
+
+
+ # This is for Assignment 12
+ # Model reused from Assignment 10 and converted to a Lightning module
+ class CustomResNet(pl.LightningModule):
+     """This defines the structure of the NN."""
+
+     # Class variable to print shapes
+     print_shape = False
+     # Default dropout value
+     dropout_value = 0.02
+
+     def __init__(self):
+         super().__init__()
+
+         # Define loss function
+         # https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html
+         self.loss_function = torch.nn.CrossEntropyLoss()
+
+         # Define accuracy function
+         # https://torchmetrics.readthedocs.io/en/stable/classification/accuracy.html
+         self.accuracy_function = Accuracy(task="multiclass", num_classes=10)
+
+         # Add results dictionary
+         self.results = {
+             "train_loss": [],
+             "train_acc": [],
+             "test_loss": [],
+             "test_acc": [],
+             "val_loss": [],
+             "val_acc": [],
+         }
+
+         # Save misclassified images
+         self.misclassified_image_data = {"images": [], "ground_truths": [], "predicted_vals": []}
+
+         # LR
+         self.learning_rate = PREFERRED_START_LR
+
+         # Model Notes
+
+         # PrepLayer - Conv 3x3 (s1, p1) >> BN >> RELU [64k]
+         # 1. Input size: 32x32x3
+         self.prep = nn.Sequential(
+             nn.Conv2d(
+                 in_channels=3,
+                 out_channels=64,
+                 kernel_size=(3, 3),
+                 stride=1,
+                 padding=1,
+                 dilation=1,
+                 bias=False,
+             ),
+             nn.BatchNorm2d(64),
+             nn.ReLU(),
+             nn.Dropout(self.dropout_value),
+         )
+
+         # Layer1: X = Conv 3x3 (s1, p1) >> MaxPool2D >> BN >> RELU [128k]
+         self.layer1_x = nn.Sequential(
+             nn.Conv2d(
+                 in_channels=64,
+                 out_channels=128,
+                 kernel_size=(3, 3),
+                 stride=1,
+                 padding=1,
+                 dilation=1,
+                 bias=False,
+             ),
+             nn.MaxPool2d(kernel_size=2, stride=2),
+             nn.BatchNorm2d(128),
+             nn.ReLU(),
+             nn.Dropout(self.dropout_value),
+         )
+
+         # Layer1: R1 = ResBlock((Conv-BN-ReLU-Conv-BN-ReLU))(X) [128k]
+         self.layer1_r1 = nn.Sequential(
+             nn.Conv2d(
+                 in_channels=128,
+                 out_channels=128,
+                 kernel_size=(3, 3),
+                 stride=1,
+                 padding=1,
+                 dilation=1,
+                 bias=False,
+             ),
+             nn.BatchNorm2d(128),
+             nn.ReLU(),
+             nn.Dropout(self.dropout_value),
+             nn.Conv2d(
+                 in_channels=128,
+                 out_channels=128,
+                 kernel_size=(3, 3),
+                 stride=1,
+                 padding=1,
+                 dilation=1,
+                 bias=False,
+             ),
+             nn.BatchNorm2d(128),
+             nn.ReLU(),
+             nn.Dropout(self.dropout_value),
+         )
+
+         # Layer 2: Conv 3x3 [256k], MaxPooling2D, BN, ReLU
+         self.layer2 = nn.Sequential(
+             nn.Conv2d(
+                 in_channels=128,
+                 out_channels=256,
+                 kernel_size=(3, 3),
+                 stride=1,
+                 padding=1,
+                 dilation=1,
+                 bias=False,
+             ),
+             nn.MaxPool2d(kernel_size=2, stride=2),
+             nn.BatchNorm2d(256),
+             nn.ReLU(),
+             nn.Dropout(self.dropout_value),
+         )
+
+         # Layer 3: X = Conv 3x3 (s1, p1) >> MaxPool2D >> BN >> RELU [512k]
+         self.layer3_x = nn.Sequential(
+             nn.Conv2d(
+                 in_channels=256,
+                 out_channels=512,
+                 kernel_size=(3, 3),
+                 stride=1,
+                 padding=1,
+                 dilation=1,
+                 bias=False,
+             ),
+             nn.MaxPool2d(kernel_size=2, stride=2),
+             nn.BatchNorm2d(512),
+             nn.ReLU(),
+             nn.Dropout(self.dropout_value),
+         )
+
+         # Layer 3: R2 = ResBlock((Conv-BN-ReLU-Conv-BN-ReLU))(X) [512k]
+         self.layer3_r2 = nn.Sequential(
+             nn.Conv2d(
+                 in_channels=512,
+                 out_channels=512,
+                 kernel_size=(3, 3),
+                 stride=1,
+                 padding=1,
+                 dilation=1,
+                 bias=False,
+             ),
+             nn.BatchNorm2d(512),
+             nn.ReLU(),
+             nn.Dropout(self.dropout_value),
+             nn.Conv2d(
+                 in_channels=512,
+                 out_channels=512,
+                 kernel_size=(3, 3),
+                 stride=1,
+                 padding=1,
+                 dilation=1,
+                 bias=False,
+             ),
+             nn.BatchNorm2d(512),
+             nn.ReLU(),
+             nn.Dropout(self.dropout_value),
+         )
+
+         # MaxPooling with kernel size 4
+         # If stride is None, it is set to kernel_size
+         self.maxpool = nn.MaxPool2d(kernel_size=4, stride=4)
+
+         # FC Layer
+         self.fc = nn.Linear(512, 10)
+
+         # Save hyperparameters
+         self.save_hyperparameters()
+
+     def print_view(self, x, msg=""):
+         """Print the shape of the tensor at this point in the model"""
+         if self.print_shape:
+             if msg != "":
+                 print(msg, "\n\t", x.shape, "\n")
+             else:
+                 print(x.shape)
+
+     def forward(self, x):
+         """Forward pass"""
+
+         # PrepLayer
+         x = self.prep(x)
+         self.print_view(x, "PrepLayer")
+
+         # Layer 1
+         x = self.layer1_x(x)
+         self.print_view(x, "Layer 1, X")
+         r1 = self.layer1_r1(x)
+         self.print_view(r1, "Layer 1, R1")
+         x = x + r1
+         self.print_view(x, "Layer 1, X + R1")
+
+         # Layer 2
+         x = self.layer2(x)
+         self.print_view(x, "Layer 2")
+
+         # Layer 3
+         x = self.layer3_x(x)
+         self.print_view(x, "Layer 3, X")
+         r2 = self.layer3_r2(x)
+         self.print_view(r2, "Layer 3, R2")
+         x = x + r2
+         self.print_view(x, "Layer 3, X + R2")
+
+         # MaxPooling
+         x = self.maxpool(x)
+         self.print_view(x, "Max Pooling")
+
+         # FC Layer
+         # Reshape before FC such that it becomes 1D
+         x = x.view(x.shape[0], -1)
+         self.print_view(x, "Reshape before FC")
+         x = self.fc(x)
+         self.print_view(x, "After FC")
+
+         # Log-softmax (pairs with torch.exp at inference time to recover probabilities)
+         return F.log_softmax(x, dim=-1)
+
+     # Alert: Remove this function later as Tuner is now being used to automatically find the best LR
+     def find_optimal_lr(self, train_loader):
+         """Use LR Finder to find the best starting learning rate"""
+
+         # https://github.com/davidtvs/pytorch-lr-finder
+         # https://github.com/davidtvs/pytorch-lr-finder#notes
+         # https://github.com/davidtvs/pytorch-lr-finder/blob/master/torch_lr_finder/lr_finder.py
+
+         # New optimizer with default LR
+         tmp_optimizer = optim.Adam(self.parameters(), lr=PREFERRED_START_LR, weight_decay=PREFERRED_WEIGHT_DECAY)
+
+         # Create LR finder object
+         lr_finder = LRFinder(self, optimizer=tmp_optimizer, criterion=self.loss_function)
+         lr_finder.range_test(train_loader=train_loader, end_lr=10, num_iter=100)
+         # https://github.com/davidtvs/pytorch-lr-finder/issues/88
+         _, suggested_lr = lr_finder.plot(suggest_lr=True)
+         lr_finder.reset()
+         # plot.figure.savefig("LRFinder - Suggested Max LR.png")
+
+         print(f"Suggested Max LR: {suggested_lr}")
+
+         if suggested_lr is None:
+             suggested_lr = PREFERRED_START_LR
+
+         return suggested_lr
+
+     # Optimizer function
+     def configure_optimizers(self):
+         """Add the Adam optimizer and OneCycleLR scheduler to the Lightning module"""
+         optimizer = optim.Adam(self.parameters(), lr=self.learning_rate, weight_decay=PREFERRED_WEIGHT_DECAY)
+
+         # Percent start for OneCycleLR
+         # Handles the case where max_epochs is less than 5
+         percent_start = 5 / int(self.trainer.max_epochs)
+         if percent_start >= 1:
+             percent_start = 0.3
+
+         # https://lightning.ai/docs/pytorch/stable/common/optimization.html#total-stepping-batches
+         scheduler_dict = {
+             "scheduler": OneCycleLR(
+                 optimizer=optimizer,
+                 max_lr=self.learning_rate,
+                 total_steps=int(self.trainer.estimated_stepping_batches),
+                 pct_start=percent_start,
+                 div_factor=100,
+                 three_phase=False,
+                 anneal_strategy="linear",
+                 final_div_factor=100,
+                 verbose=False,
+             ),
+             "interval": "step",
+         }
+
+         return {"optimizer": optimizer, "lr_scheduler": scheduler_dict}
+
+     # Define loss function
+     def compute_loss(self, prediction, target):
+         """Compute loss"""
+
+         # Calculate loss
+         loss = self.loss_function(prediction, target)
+
+         return loss
+
+     # Define accuracy function
+     def compute_accuracy(self, prediction, target):
+         """Compute accuracy"""
+
+         # Calculate accuracy
+         acc = self.accuracy_function(prediction, target)
+
+         return acc * 100
+
+     # Function to compute loss and accuracy for both training and validation
+     def compute_metrics(self, batch):
+         """Calculate loss and accuracy for a batch"""
+
+         # Get data and target from batch
+         data, target = batch
+
+         # Generate predictions using the model
+         pred = self(data)
+
+         # Calculate loss for the batch
+         loss = self.compute_loss(prediction=pred, target=target)
+
+         # Calculate accuracy for the batch
+         acc = self.compute_accuracy(prediction=pred, target=target)
+
+         return loss, acc
+
+     # Get misclassified images based on how many images to return
+     def store_misclassified_images(self):
+         """Collect misclassified images from the test dataloader"""
+
+         self.misclassified_image_data = {"images": [], "ground_truths": [], "predicted_vals": []}
+
+         # Put the model in evaluation mode
+         self.eval()
+
+         # Disable gradient calculation while testing
+         with torch.no_grad():
+             for batch in self.trainer.test_dataloaders:
+                 # Move data and labels to device
+                 data, target = batch
+                 data, target = data.to(self.device), target.to(self.device)
+
+                 # Predict using the model
+                 pred = self(data)
+
+                 # Get the index of the max log-probability
+                 output = pred.argmax(dim=1)
+
+                 # Indices of the incorrect predictions
+                 incorrect_indices = ~output.eq(target)
+
+                 # Store the incorrectly predicted images, the predictions and the actual values
+                 self.misclassified_image_data["images"].extend(data[incorrect_indices])
+                 self.misclassified_image_data["ground_truths"].extend(target[incorrect_indices])
+                 self.misclassified_image_data["predicted_vals"].extend(output[incorrect_indices])
+
+     # Training function
+     def training_step(self, batch, batch_idx):
+         """Training step"""
+
+         # Compute loss and accuracy
+         loss, acc = self.compute_metrics(batch)
+
+         self.log("train_loss", loss, prog_bar=True, on_epoch=True, logger=True)
+         self.log("train_acc", acc, prog_bar=True, on_epoch=True, logger=True)
+         # Return training loss
+         return loss
+
+     # Validation function
+     def validation_step(self, batch, batch_idx):
+         """Validation step"""
+
+         # Compute loss and accuracy
+         loss, acc = self.compute_metrics(batch)
+
+         self.log("val_loss", loss, prog_bar=True, on_epoch=True, logger=True)
+         self.log("val_acc", acc, prog_bar=True, on_epoch=True, logger=True)
+         # Return validation loss
+         return loss
+
+     # Test function reuses the validation logic
+     def test_step(self, batch, batch_idx):
+         """Test step"""
+
+         # Compute loss and accuracy
+         loss, acc = self.compute_metrics(batch)
+
+         self.log("test_loss", loss, prog_bar=False, on_epoch=True, logger=True)
+         self.log("test_acc", acc, prog_bar=False, on_epoch=True, logger=True)
+         # Return test loss
+         return loss
+
+     # At the end of each training epoch, append the training loss and accuracy to the results dictionary
+     def on_train_epoch_end(self):
+         """On train epoch end"""
+
+         # Append training loss and accuracy to results
+         self.results["train_loss"].append(self.trainer.callback_metrics["train_loss"].detach().item())
+         self.results["train_acc"].append(self.trainer.callback_metrics["train_acc"].detach().item())
+
+     # At the end of each validation epoch, append the validation loss and accuracy to the results dictionary
+     # Note: validation metrics are stored under the "test_*" keys so the plotting code can show them as test curves
+     def on_validation_epoch_end(self):
+         """On validation epoch end"""
+
+         # Append validation loss and accuracy to results
+         self.results["test_loss"].append(self.trainer.callback_metrics["val_loss"].detach().item())
+         self.results["test_acc"].append(self.trainer.callback_metrics["val_acc"].detach().item())
+
+     # # At the end of the test epoch, append the test loss and accuracy to the results dictionary
+     # def on_test_epoch_end(self):
+     #     """On test epoch end"""
+
+     #     # Append test loss and accuracy to results
+     #     self.results["test_loss"].append(self.trainer.callback_metrics["test_loss"].detach().item())
+     #     self.results["test_acc"].append(self.trainer.callback_metrics["test_acc"].detach().item())
+
+     # At the end of the test run, save the misclassified images, predictions and ground truths
+     def on_test_end(self):
+         """On test end"""
+
+         print("Test ended! Saving misclassified images")
+         # Get misclassified images
+         self.store_misclassified_images()
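A minimal shape check, assuming only the module above, confirms that the forward pass maps a CIFAR-sized batch to ten log-probabilities per image:

    import torch
    from modules.custom_resnet import CustomResNet

    model = CustomResNet().eval()
    with torch.no_grad():
        # Dummy batch of two 3x32x32 images
        logits = model(torch.randn(2, 3, 32, 32))
    print(logits.shape)  # expected: torch.Size([2, 10]), log-probabilities over the CIFAR10 classes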
modules/visualize.py ADDED
@@ -0,0 +1,170 @@
+ import matplotlib.pyplot as plt
+ import numpy as np
+ from pytorch_grad_cam import GradCAM
+ from pytorch_grad_cam.utils.image import show_cam_on_image
+
+
+ def convert_back_image(image):
+     """Denormalize the image back to displayable values using the dataset mean and standard deviation"""
+     cifar10_mean = (0.4914, 0.4822, 0.4471)
+     cifar10_std = (0.2469, 0.2433, 0.2615)
+     image = image.numpy().astype(dtype=np.float32)
+
+     for i in range(image.shape[0]):
+         image[i] = (image[i] * cifar10_std[i]) + cifar10_mean[i]
+
+     # Clip so matplotlib does not warn that pixel values exceed bounds
+     image = image.clip(0, 1)
+
+     return np.transpose(image, (1, 2, 0))
+
+
+ def plot_sample_training_images(batch_data, batch_label, class_label, num_images=30):
+     """Plot sample images from the training data."""
+     images, labels = batch_data, batch_label
+
+     # Calculate the number of images to plot
+     num_images = min(num_images, len(images))
+     # Calculate the number of rows and columns to plot
+     num_cols = 5
+     num_rows = int(np.ceil(num_images / num_cols))
+
+     # Initialize a subplot with the required number of rows and columns
+     fig, axs = plt.subplots(num_rows, num_cols, figsize=(10, 10))
+
+     # Iterate through the images and plot them in the grid along with class labels
+     for img_index in range(1, num_images + 1):
+         plt.subplot(num_rows, num_cols, img_index)
+         plt.tight_layout()
+         plt.axis("off")
+         plt.imshow(convert_back_image(images[img_index - 1]))
+         plt.title(class_label[labels[img_index - 1].item()])
+         plt.xticks([])
+         plt.yticks([])
+
+     return fig, axs
+
+
+ def plot_train_test_metrics(results):
+     """
+     Plot the training and test metrics.
+     """
+     # Extract train_losses, train_acc, test_losses, test_acc from results
+     train_losses = results["train_loss"]
+     train_acc = results["train_acc"]
+     test_losses = results["test_loss"]
+     test_acc = results["test_acc"]
+
+     # Plot the graphs in a 1x2 grid showing the training and test metrics
+     fig, axs = plt.subplots(1, 2, figsize=(16, 8))
+
+     # Loss plot
+     axs[0].plot(train_losses, label="Train")
+     axs[0].plot(test_losses, label="Test")
+     axs[0].set_title("Loss")
+     axs[0].legend(loc="upper right")
+
+     # Accuracy plot
+     axs[1].plot(train_acc, label="Train")
+     axs[1].plot(test_acc, label="Test")
+     axs[1].set_title("Accuracy")
+     axs[1].legend(loc="upper right")
+
+     return fig, axs
+
+
+ def plot_misclassified_images(data, class_label, num_images=10):
+     """Plot the misclassified images from the test dataset."""
+     # Calculate the number of images to plot
+     num_images = min(num_images, len(data["ground_truths"]))
+     # Calculate the number of rows and columns to plot
+     num_cols = 5
+     num_rows = int(np.ceil(num_images / num_cols))
+
+     # Initialize a subplot with the required number of rows and columns
+     fig, axs = plt.subplots(num_rows, num_cols, figsize=(num_cols * 2, num_rows * 2))
+
+     # Iterate through the images and plot them in the grid along with class labels
+     for img_index in range(1, num_images + 1):
+         # Get the ground truth and predicted labels for the image
+         label = data["ground_truths"][img_index - 1].cpu().item()
+         pred = data["predicted_vals"][img_index - 1].cpu().item()
+         # Get the image
+         image = data["images"][img_index - 1].cpu()
+         # Plot the image
+         plt.subplot(num_rows, num_cols, img_index)
+         plt.tight_layout()
+         plt.axis("off")
+         plt.imshow(convert_back_image(image))
+         plt.title(f"""ACT: {class_label[label]} \nPRED: {class_label[pred]}""")
+         plt.xticks([])
+         plt.yticks([])
+
+     return fig, axs
+
+
+ # Function to plot GradCAM for misclassified images using pytorch_grad_cam
+ def plot_gradcam_images(
+     model,
+     data,
+     class_label,
+     target_layers,
+     targets=None,
+     num_images=10,
+     image_weight=0.25,
+ ):
+     """Show GradCAM for misclassified images"""
+
+     # Calculate the number of images to plot
+     num_images = min(num_images, len(data["ground_truths"]))
+     # Calculate the number of rows and columns to plot
+     num_cols = 5
+     num_rows = int(np.ceil(num_images / num_cols))
+
+     # Initialize a subplot with the required number of rows and columns
+     fig, axs = plt.subplots(num_rows, num_cols, figsize=(num_cols * 2, num_rows * 2))
+
+     # Initialize the GradCAM object
+     # https://github.com/jacobgil/pytorch-grad-cam/blob/master/pytorch_grad_cam/grad_cam.py
+     # https://github.com/jacobgil/pytorch-grad-cam/blob/master/pytorch_grad_cam/base_cam.py
+     # Alert: Change the device to cpu for the Gradio app
+     cam = GradCAM(model=model, target_layers=target_layers, use_cuda=False)
+
+     # Iterate through the images and plot them in the grid along with class labels
+     for img_index in range(1, num_images + 1):
+         # Extract elements from the data dictionary
+         # Get the ground truth and predicted labels for the image
+         label = data["ground_truths"][img_index - 1].cpu().item()
+         pred = data["predicted_vals"][img_index - 1].cpu().item()
+         # Get the image
+         image = data["images"][img_index - 1].cpu()
+
+         # Get the GradCAM output
+         # https://github.com/jacobgil/pytorch-grad-cam/blob/master/pytorch_grad_cam/utils/model_targets.py
+         grad_cam_output = cam(
+             input_tensor=image.unsqueeze(0),
+             targets=targets,
+             aug_smooth=True,
+             eigen_smooth=True,
+         )
+         grad_cam_output = grad_cam_output[0, :]
+
+         # Overlay GradCAM on top of the numpy image
+         overlayed_image = show_cam_on_image(
+             convert_back_image(image),
+             grad_cam_output,
+             use_rgb=True,
+             image_weight=image_weight,
+         )
+
+         # Plot the image
+         plt.subplot(num_rows, num_cols, img_index)
+         plt.tight_layout()
+         plt.axis("off")
+         plt.imshow(overlayed_image)
+         plt.title(f"""ACT: {class_label[label]} \nPRED: {class_label[pred]}""")
+         plt.xticks([])
+         plt.yticks([])
+     return fig, axs
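A hedged usage sketch, assuming the misclassified-data tensor saved by the training run exists at the path in config:

    import torch
    import modules.config as config
    from modules.visualize import plot_misclassified_images

    # The dict is expected to hold "images", "ground_truths" and "predicted_vals" tensors
    data = torch.load(config.MISCLASSIFIED_PATH, map_location=torch.device("cpu"))
    fig, axs = plot_misclassified_images(data, class_label=list(config.CIFAR_CLASSES), num_images=10)
    fig.savefig("misclassified.png")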
requirements.txt ADDED
@@ -0,0 +1,11 @@
+ albumentations==1.3.1
+ grad-cam==1.4.8
+ gradio==3.39.0
+ numpy==1.25.0
+ pillow==9.4.0
+ pytorch-lightning==2.0.6
+ torch==2.0.1
+ torch_lr_finder==0.2.1
+ torchinfo==1.8.0
+ torchmetrics==0.11.4
+ torchvision==0.15.2