soutrik committed on
Commit 29730dd
1 Parent(s): d200061

added: model and code and app
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ assets/model/*.pt filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,13 +1,80 @@
  ---
- title: ERAv2PytorchClassificationLightning
+ title: Erav2s13
  emoji: 🔥
- colorFrom: pink
- colorTo: blue
+ colorFrom: yellow
+ colorTo: red
  sdk: gradio
- sdk_version: 4.31.3
+ sdk_version: 4.27.0
  app_file: app.py
  pinned: false
- license: apache-2.0
+ license: mit
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # Erav2s13 - SOUTRIK 🔥
+
+ ## Overview
+ This repository is hosted as a Hugging Face Space and uses Gradio to build the user interface (UI). The model was trained on Google Colab, and the resulting model files are used for inference in the Gradio app.
+
+ - **Model Training**: `main.ipynb` - the Colab notebook used to build and train the model.
+ - **Inference**: The same model structure and files are used in the Gradio app.
+
+ ## Custom ResNet Model
+ The `custom_resnet.py` file defines a custom ResNet (Residual Network) model using PyTorch Lightning. This model is designed for image classification, specifically for the CIFAR-10 dataset.
+
+ ### Model Architecture
+ The custom ResNet model comprises the following components (a minimal usage sketch follows the list):
+
+ 1. **Preparation Layer**: Convolutional layer with 64 filters, followed by batch normalization, ReLU activation, and dropout.
+ 2. **Layer 1**: Convolutional layer with 128 filters, max pooling, batch normalization, ReLU activation, and dropout. Includes a residual block with two convolutional layers (128 filters each), batch normalization, ReLU activation, and dropout.
+ 3. **Layer 2**: Convolutional layer with 256 filters, max pooling, batch normalization, ReLU activation, and dropout.
+ 4. **Layer 3**: Convolutional layer with 512 filters, max pooling, batch normalization, ReLU activation, and dropout. Includes a residual block with two convolutional layers (512 filters each), batch normalization, ReLU activation, and dropout.
+ 5. **Max Pooling**: Max pooling layer with a kernel size of 4.
+ 6. **Fully Connected Layer**: Flattened output passed through a fully connected layer with 10 output units (one per CIFAR-10 class).
+ 7. **Softmax**: Log softmax activation function to obtain predicted class log-probabilities.
+
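+ To sanity-check the wiring, the model can be instantiated and run on a random batch. A minimal sketch (import path and shapes follow the repository layout of this commit):
+
+ ```python
+ import torch
+ from models.custom_resnet import CustomResNet
+
+ model = CustomResNet()
+ model.eval()
+
+ # CIFAR-10 images are 3x32x32; use a random batch of 4
+ x = torch.randn(4, 3, 32, 32)
+ with torch.no_grad():
+     log_probs = model(x)  # forward() returns log-softmax outputs
+
+ print(log_probs.shape)             # torch.Size([4, 10])
+ print(log_probs.exp().sum(dim=1))  # probabilities per row sum to ~1
+ ```
+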
+ ### Training and Evaluation
+ The model is trained using PyTorch Lightning, which provides a high-level interface for training, validation, and testing. Key components (see the sketch after this list):
+
+ - **Optimizer**: Adam with a learning rate specified by `PREFERRED_START_LR`.
+ - **Scheduler**: OneCycleLR for learning rate adjustment.
+ - **Loss and Accuracy**: Cross-entropy loss and accuracy are computed and logged during training, validation, and testing.
+
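+ A minimal end-to-end training sketch under these settings (illustrative values taken from `modules/config.py`; the actual run lives in `main.ipynb` and `modules/trainer.py`):
+
+ ```python
+ import pytorch_lightning as pl
+ from models.custom_resnet import CustomResNet
+ from modules.lightning_dataset import CIFARDataModule
+
+ datamodule = CIFARDataModule(data_path="./data", batch_size=512, seed=53, num_workers=2)
+ model = CustomResNet()
+
+ # 24 epochs and 16-bit precision mirror modules/config.py and modules/trainer.py;
+ # mixed precision assumes a GPU runtime
+ trainer = pl.Trainer(max_epochs=24, precision=16)
+ trainer.fit(model, datamodule=datamodule)
+ trainer.test(model, datamodule=datamodule)
+ ```
+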
+ ### Misclassified Images
+ During testing, misclassified images are tracked and stored in a dictionary along with their ground-truth and predicted labels, facilitating error analysis and model improvement.
+
+ ### Hyperparameters
+ Key hyperparameters include:
+
+ - `PREFERRED_START_LR`: Initial learning rate.
+ - `PREFERRED_WEIGHT_DECAY`: Weight decay for regularization.
+
+ ### Model Summary
+ The `detailed_model_summary` function prints a comprehensive summary of the model architecture, detailing input size, kernel size, output size, number of parameters, and trainable status of each layer; see the usage sketch below.
+
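+ For example (a small sketch; the input size excludes the batch dimension, which the helper adds via `batch_dim=0`):
+
+ ```python
+ from models.custom_resnet import CustomResNet, detailed_model_summary
+
+ model = CustomResNet()
+ # Prints per-layer input size, kernel size, output size, parameter count, and trainable flag
+ detailed_model_summary(model, input_size=(3, 32, 32))
+ ```
+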
+ ## Lightning Dataset Module
+ The `lightning_dataset.py` file contains the `CIFARDataModule` class, a PyTorch Lightning `LightningDataModule` for the CIFAR-10 dataset. This class handles data preparation, splitting, and loading (a usage sketch follows the lists below).
+
+ ### CIFARDataModule Class
+
+ #### Parameters
+ - `data_path`: Directory path for the CIFAR-10 dataset.
+ - `batch_size`: Batch size for the data loaders.
+ - `seed`: Random seed for reproducibility.
+ - `val_split`: Fraction of the training data used for validation (default: 0).
+ - `num_workers`: Number of worker processes for data loading (default: 0).
+
+ #### Methods
+ - `prepare_data`: Downloads the CIFAR-10 dataset if not present.
+ - `setup`: Defines data transformations and creates the training, validation, and testing datasets.
+ - `train_dataloader`: Returns the training data loader.
+ - `val_dataloader`: Returns the validation data loader.
+ - `test_dataloader`: Returns the testing data loader.
+
+ #### Utility Methods
+ - `_split_train_val`: Splits the training dataset into training and validation subsets.
+ - `_init_fn`: Initializes the random seed for each worker process to ensure reproducibility.
+
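+ A minimal standalone usage sketch (path, split, and batch size are illustrative):
+
+ ```python
+ from modules.lightning_dataset import CIFARDataModule
+
+ dm = CIFARDataModule(data_path="./data", batch_size=128, seed=53, val_split=0.1)
+ dm.prepare_data()      # downloads CIFAR-10 on first use
+ dm.setup(stage="fit")  # builds the train/val datasets with their transforms
+
+ images, labels = next(iter(dm.train_dataloader()))
+ print(images.shape, labels.shape)  # torch.Size([128, 3, 32, 32]) torch.Size([128])
+ ```
+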
+ ## License
+ This project is licensed under the MIT License.
+
+ ---
app.py ADDED
@@ -0,0 +1,208 @@
+ import gradio as gr
+
+ from models.custom_resnet import CustomResNet
+ from modules.visualize import plot_gradcam_images, plot_misclassified_images
+ from pytorch_grad_cam import GradCAM
+ from pytorch_grad_cam.utils.image import show_cam_on_image
+ from torchvision import transforms
+ import modules.config as config
+ import numpy as np
+ import torch
+ from PIL import Image
+
+
+ TITLE = "CIFAR10 Image classification using a Custom ResNet Model"
+ DESCRIPTION = "Gradio App to infer using a Custom ResNet model and get GradCAM results"
+ examples = [
+     ["assets/images/airplane.jpg", 3, True, "layer3_x", 0.6, True, 5, True, 5],
+     ["assets/images/bird.jpeg", 4, True, "layer3_x", 0.7, True, 10, True, 20],
+     ["assets/images/car.jpg", 5, True, "layer3_x", 0.5, True, 15, True, 5],
+     ["assets/images/cat.jpeg", 6, True, "layer3_x", 0.65, True, 20, True, 10],
+     ["assets/images/deer.jpg", 7, False, "layer2", 0.75, True, 5, True, 5],
+     ["assets/images/dog.jpg", 8, True, "layer2", 0.55, True, 10, True, 5],
+     ["assets/images/frog.jpeg", 9, True, "layer2", 0.8, True, 15, True, 15],
+     ["assets/images/horse.jpg", 10, False, "layer1_r1", 0.85, True, 20, True, 5],
+     ["assets/images/ship.jpg", 3, True, "layer1_r1", 0.4, True, 5, True, 15],
+     ["assets/images/truck.jpg", 4, True, "layer1_r1", 0.3, True, 5, True, 10],
+ ]
+
+
+ # Load and initialise the model
+ model = CustomResNet()
+
+ # Define the device
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ # Using the checkpoint path present in config, load the trained weights
+ model.load_state_dict(torch.load(config.GRADIO_MODEL_PATH, map_location=device), strict=False)
+ # Move the model to the selected device
+ model.to(device)
+ # Put the model in evaluation mode
+ model.eval()
+
+ # Load the misclassified images data
+ misclassified_image_data = torch.load(config.GRADIO_MISCLASSIFIED_PATH, map_location=device)
+
+ # Class names
+ classes = list(config.CIFAR_CLASSES)
+ # Allowed layer names
+ model_layer_names = ["prep", "layer1_x", "layer1_r1", "layer2", "layer3_x", "layer3_r2"]
+
+
+ def get_target_layer(layer_name):
+     """Get target layer for visualization"""
+     if layer_name == "prep":
+         return [model.prep[-1]]
+     elif layer_name == "layer1_x":
+         return [model.layer1_x[-1]]
+     elif layer_name == "layer1_r1":
+         return [model.layer1_r1[-1]]
+     elif layer_name == "layer2":
+         return [model.layer2[-1]]
+     elif layer_name == "layer3_x":
+         return [model.layer3_x[-1]]
+     elif layer_name == "layer3_r2":
+         return [model.layer3_r2[-1]]
+     else:
+         return None
+
+
+ def generate_prediction(input_image, num_classes=3, show_gradcam=True, transparency=0.6, layer_name="layer3_x"):
+     """Given an input image, generate the prediction, confidences, and display image."""
+     mean = list(config.CIFAR_MEAN)
+     std = list(config.CIFAR_STD)
+     transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])
+
+     original_img = input_image
+     input_image = transform(input_image).unsqueeze(0).to(device)
+
+     # Inference does not need gradients; GradCAM below does, so keep it outside this block
+     with torch.no_grad():
+         model_output = model(input_image)
+         # The model returns log-softmax, so exponentiate to get probabilities
+         output_exp = torch.exp(model_output)
+
+     output_numpy = np.squeeze(output_exp.cpu().numpy())
+     # Get indices of probabilities in descending order
+     sorted_indexes = np.argsort(output_numpy)[::-1]
+
+     confidences = {}
+     for idx in range(int(num_classes)):
+         # Map the top-k classes to their predicted probabilities
+         confidences[classes[sorted_indexes[idx]]] = float(output_numpy[sorted_indexes[idx]])
+
+     # Show GradCAM overlay if requested
+     if show_gradcam:
+         # Get the target layer
+         target_layers = get_target_layer(layer_name)
+         cam = GradCAM(model=model, target_layers=target_layers)
+         cam_generated = cam(input_tensor=input_image, targets=None)
+         cam_generated = cam_generated[0, :]
+         display_image = show_cam_on_image(original_img / 255, cam_generated, use_rgb=True, image_weight=transparency)
+     else:
+         display_image = original_img
+
+     return confidences, display_image
+
+
+ def app_interface(
+     input_image,
+     num_classes,
+     show_gradcam,
+     layer_name,
+     transparency,
+     show_misclassified,
+     num_misclassified,
+     show_gradcam_misclassified,
+     num_gradcam_misclassified,
+ ):
+     """Function which provides the Gradio interface"""
+     input_image = resize_image_pil(input_image, 32, 32)
+
+     input_image = np.array(input_image)
+     org_img = input_image
+     # Get the prediction for the input image along with confidences and display image
+     confidences, display_image = generate_prediction(org_img, num_classes, show_gradcam, transparency, layer_name)
+
+     if show_misclassified:
+         misclassified_fig, misclassified_axs = plot_misclassified_images(
+             data=misclassified_image_data, class_label=classes, num_images=num_misclassified
+         )
+     else:
+         misclassified_fig = None
+
+     if show_gradcam_misclassified:
+         gradcam_fig, gradcam_axs = plot_gradcam_images(
+             model=model,
+             data=misclassified_image_data,
+             class_label=classes,
+             target_layers=get_target_layer(layer_name),
+             targets=None,
+             num_images=num_gradcam_misclassified,
+             image_weight=transparency,
+         )
+     else:
+         gradcam_fig = None
+
+     return confidences, display_image, misclassified_fig, gradcam_fig
+
+
+ def resize_image_pil(image, new_width, new_height):
+     """Resize an image with PIL, preserving aspect ratio, then crop to the exact size."""
+     # Convert to PIL image
+     img = Image.fromarray(np.array(image))
+
+     # Get original size
+     width, height = img.size
+
+     # Calculate scale
+     width_scale = new_width / width
+     height_scale = new_height / height
+     scale = min(width_scale, height_scale)
+
+     # Resize
+     resized = img.resize((int(width * scale), int(height * scale)), Image.NEAREST)
+
+     # Crop to exact size
+     resized = resized.crop((0, 0, new_width, new_height))
+
+     return resized
+
+
+ inference_app = gr.Interface(
+     app_interface,
+     inputs=[
+         # The image is resized to 32x32 inside app_interface, which is what the model expects
+         gr.Image(width=256, height=256, label="Input Image"),
+         gr.Number(value=3, maximum=10, minimum=1, step=1.0, precision=0, label="#Classes to show"),
+         gr.Checkbox(True, label="Show GradCAM Image"),
+         gr.Dropdown(model_layer_names, value="layer3_x", label="Visualization Layer from Model"),
+         # How much the CAM should be overlaid on the original image
+         gr.Slider(0, 1, 0.6, label="Image Overlay Factor"),
+         gr.Checkbox(True, label="Show Misclassified Images?"),
+         gr.Slider(value=10, maximum=25, minimum=5, step=5.0, label="#Misclassified images to show"),
+         gr.Checkbox(True, label="Visualize GradCAM for Misclassified images?"),
+         gr.Slider(value=10, maximum=25, minimum=5, step=5.0, label="#GradCAM images to show"),
+     ],
+     outputs=[
+         gr.Label(label="Confidences", container=True, show_label=True),
+         gr.Image(label="GradCAM / Input Image", container=True, show_label=True, height=256, width=256),
+         gr.Plot(label="Misclassified images", container=True, show_label=True),
+         gr.Plot(label="GradCAM of Misclassified images", container=True, show_label=True),
+     ],
+     title=TITLE,
+     description=DESCRIPTION,
+     examples=examples,
+ )
+ inference_app.launch()
images/airplane.jpg ADDED
images/bird.jpeg ADDED
images/car.jpg ADDED
images/cat.jpeg ADDED
images/deer.jpg ADDED
images/dog.jpg ADDED
images/frog.jpeg ADDED
images/horse.jpg ADDED
images/ship.jpg ADDED
images/truck.jpg ADDED
main.ipynb ADDED
The diff for this file is too large to render.
 
model.py ADDED
@@ -0,0 +1,8 @@
+ import cv2
+ import numpy as np
+
+
+ def inverse_pic(input_img):
+     """Convert a BGR image to RGB and flip it (standalone helper; not imported by app.py)."""
+     input_img = cv2.cvtColor(input_img, cv2.COLOR_BGR2RGB)
+     return np.flip(input_img)
model/CustomResNet.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:fe9f24bd3c056c0b4e4c7687678a941ebe0a51ab39ec6d83500ccc02ec2a6574
+ size 26326990
model/Misclassified_Data.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:05b440a0bdf7f1996bbca47cd992360b9fb356f195fe7afe38f4ccd047463e58
+ size 485025
model/sample.txt ADDED
File without changes
models/__pycache__/custom_resnet.cpython-311.pyc ADDED
Binary file (15.2 kB).
 
models/custom_resnet.py ADDED
@@ -0,0 +1,456 @@
+ """Module to define the model."""
+
+ # Resources
+ # https://lightning.ai/docs/pytorch/stable/starter/introduction.html
+ # https://lightning.ai/docs/pytorch/stable/starter/converting.html
+ # https://lightning.ai/docs/pytorch/stable/notebooks/lightning_examples/cifar10-baseline.html
+
+ import modules.config as config
+ import pytorch_lightning as pl
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torch.optim as optim
+ import torchinfo
+ from torch.optim.lr_scheduler import OneCycleLR
+ from torch_lr_finder import LRFinder
+ from torchmetrics import Accuracy
+
+ # What is the start LR and weight decay you'd prefer?
+ PREFERRED_START_LR = config.PREFERRED_START_LR
+ PREFERRED_WEIGHT_DECAY = config.PREFERRED_WEIGHT_DECAY
+
+
+ def detailed_model_summary(model, input_size):
+     """Print a detailed, per-layer summary of the model."""
+
+     # https://github.com/TylerYep/torchinfo
+     torchinfo.summary(
+         model,
+         input_size=input_size,
+         batch_dim=0,
+         col_names=(
+             "input_size",
+             "kernel_size",
+             "output_size",
+             "num_params",
+             "trainable",
+         ),
+         verbose=1,
+         col_width=16,
+     )
+
+
+ ############# Assignment 13 Model #############
+
+
+ # This is for Assignment 13
+ # Model used from Assignment 11 and converted to a Lightning module
+ class CustomResNet(pl.LightningModule):
+     """This defines the structure of the NN."""
+
+     # Class variable to print shape
+     print_shape = False
+     # Default dropout value
+     dropout_value = 0.02
+
+     def __init__(self):
+         super().__init__()
+
+         # Define loss function
+         # https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html
+         self.loss_function = torch.nn.CrossEntropyLoss()
+
+         # Define accuracy function
+         # https://torchmetrics.readthedocs.io/en/stable/classification/accuracy.html
+         self.accuracy_function = Accuracy(task="multiclass", num_classes=10)
+
+         # Add results dictionary
+         self.results = {
+             "train_loss": [],
+             "train_acc": [],
+             "test_loss": [],
+             "test_acc": [],
+             "val_loss": [],
+             "val_acc": [],
+         }
+
+         # Save misclassified images
+         self.misclassified_image_data = {"images": [], "ground_truths": [], "predicted_vals": []}
+
+         # LR
+         self.learning_rate = PREFERRED_START_LR
+
+         # Model Notes
+
+         # PrepLayer - Conv 3x3 (s1, p1) >> BN >> RELU [64k]
+         # 1. Input size: 32x32x3
+         self.prep = nn.Sequential(
+             nn.Conv2d(
+                 in_channels=3,
+                 out_channels=64,
+                 kernel_size=(3, 3),
+                 stride=1,
+                 padding=1,
+                 dilation=1,
+                 bias=False,
+             ),
+             nn.BatchNorm2d(64),
+             nn.ReLU(),
+             nn.Dropout(self.dropout_value),
+         )
+
+         # Layer1: X = Conv 3x3 (s1, p1) >> MaxPool2D >> BN >> RELU [128k]
+         self.layer1_x = nn.Sequential(
+             nn.Conv2d(
+                 in_channels=64,
+                 out_channels=128,
+                 kernel_size=(3, 3),
+                 stride=1,
+                 padding=1,
+                 dilation=1,
+                 bias=False,
+             ),
+             nn.MaxPool2d(kernel_size=2, stride=2),
+             nn.BatchNorm2d(128),
+             nn.ReLU(),
+             nn.Dropout(self.dropout_value),
+         )
+
+         # Layer1: R1 = ResBlock( (Conv-BN-ReLU-Conv-BN-ReLU))(X) [128k]
+         self.layer1_r1 = nn.Sequential(
+             nn.Conv2d(
+                 in_channels=128,
+                 out_channels=128,
+                 kernel_size=(3, 3),
+                 stride=1,
+                 padding=1,
+                 dilation=1,
+                 bias=False,
+             ),
+             nn.BatchNorm2d(128),
+             nn.ReLU(),
+             nn.Dropout(self.dropout_value),
+             nn.Conv2d(
+                 in_channels=128,
+                 out_channels=128,
+                 kernel_size=(3, 3),
+                 stride=1,
+                 padding=1,
+                 dilation=1,
+                 bias=False,
+             ),
+             nn.BatchNorm2d(128),
+             nn.ReLU(),
+             nn.Dropout(self.dropout_value),
+         )
+
+         # Layer 2: Conv 3x3 [256k], MaxPooling2D, BN, ReLU
+         self.layer2 = nn.Sequential(
+             nn.Conv2d(
+                 in_channels=128,
+                 out_channels=256,
+                 kernel_size=(3, 3),
+                 stride=1,
+                 padding=1,
+                 dilation=1,
+                 bias=False,
+             ),
+             nn.MaxPool2d(kernel_size=2, stride=2),
+             nn.BatchNorm2d(256),
+             nn.ReLU(),
+             nn.Dropout(self.dropout_value),
+         )
+
+         # Layer 3: X = Conv 3x3 (s1, p1) >> MaxPool2D >> BN >> RELU [512k]
+         self.layer3_x = nn.Sequential(
+             nn.Conv2d(
+                 in_channels=256,
+                 out_channels=512,
+                 kernel_size=(3, 3),
+                 stride=1,
+                 padding=1,
+                 dilation=1,
+                 bias=False,
+             ),
+             nn.MaxPool2d(kernel_size=2, stride=2),
+             nn.BatchNorm2d(512),
+             nn.ReLU(),
+             nn.Dropout(self.dropout_value),
+         )
+
+         # Layer 3: R2 = ResBlock( (Conv-BN-ReLU-Conv-BN-ReLU))(X) [512k]
+         self.layer3_r2 = nn.Sequential(
+             nn.Conv2d(
+                 in_channels=512,
+                 out_channels=512,
+                 kernel_size=(3, 3),
+                 stride=1,
+                 padding=1,
+                 dilation=1,
+                 bias=False,
+             ),
+             nn.BatchNorm2d(512),
+             nn.ReLU(),
+             nn.Dropout(self.dropout_value),
+             nn.Conv2d(
+                 in_channels=512,
+                 out_channels=512,
+                 kernel_size=(3, 3),
+                 stride=1,
+                 padding=1,
+                 dilation=1,
+                 bias=False,
+             ),
+             nn.BatchNorm2d(512),
+             nn.ReLU(),
+             nn.Dropout(self.dropout_value),
+         )
+
+         # MaxPooling with Kernel Size 4
+         # If stride is None, it is set to kernel_size
+         self.maxpool = nn.MaxPool2d(kernel_size=4, stride=4)
+
+         # FC Layer
+         self.fc = nn.Linear(512, 10)
+
+         # Save hyperparameters
+         self.save_hyperparameters()
+
+     def print_view(self, x, msg=""):
+         """Print the shape of the tensor when print_shape is enabled"""
+         if self.print_shape:
+             if msg != "":
+                 print(msg, "\n\t", x.shape, "\n")
+             else:
+                 print(x.shape)
+
+     def forward(self, x):
+         """Forward pass"""
+
+         # PrepLayer
+         x = self.prep(x)
+         self.print_view(x, "PrepLayer")
+
+         # Layer 1
+         x = self.layer1_x(x)
+         self.print_view(x, "Layer 1, X")
+         r1 = self.layer1_r1(x)
+         self.print_view(r1, "Layer 1, R1")
+         x = x + r1
+         self.print_view(x, "Layer 1, X + R1")
+
+         # Layer 2
+         x = self.layer2(x)
+         self.print_view(x, "Layer 2")
+
+         # Layer 3
+         x = self.layer3_x(x)
+         self.print_view(x, "Layer 3, X")
+         r2 = self.layer3_r2(x)
+         self.print_view(r2, "Layer 3, R2")
+         x = x + r2
+         self.print_view(x, "Layer 3, X + R2")
+
+         # MaxPooling
+         x = self.maxpool(x)
+         self.print_view(x, "Max Pooling")
+
+         # FC Layer
+         # Reshape before FC such that it becomes 1D
+         x = x.view(x.shape[0], -1)
+         self.print_view(x, "Reshape before FC")
+         x = self.fc(x)
+         self.print_view(x, "After FC")
+
+         # Softmax
+         return F.log_softmax(x, dim=-1)
+
+     # Alert: Remove this function later as Tuner is now being used to automatically find the best LR
+     def find_optimal_lr(self, train_loader):
+         """Use LR Finder to find the best starting learning rate"""
+
+         # https://github.com/davidtvs/pytorch-lr-finder
+         # https://github.com/davidtvs/pytorch-lr-finder#notes
+         # https://github.com/davidtvs/pytorch-lr-finder/blob/master/torch_lr_finder/lr_finder.py
+
+         # New optimizer with default LR
+         tmp_optimizer = optim.Adam(self.parameters(), lr=PREFERRED_START_LR, weight_decay=PREFERRED_WEIGHT_DECAY)
+
+         # Create LR finder object
+         lr_finder = LRFinder(self, optimizer=tmp_optimizer, criterion=self.loss_function)
+         lr_finder.range_test(train_loader=train_loader, end_lr=10, num_iter=100)
+         # https://github.com/davidtvs/pytorch-lr-finder/issues/88
+         _, suggested_lr = lr_finder.plot(suggest_lr=True)
+         lr_finder.reset()
+         # plot.figure.savefig("LRFinder - Suggested Max LR.png")
+
+         print(f"Suggested Max LR: {suggested_lr}")
+
+         if suggested_lr is None:
+             suggested_lr = PREFERRED_START_LR
+
+         return suggested_lr
+
+     # Optimizer function
+     def configure_optimizers(self):
+         """Add ADAM optimizer to the lightning module"""
+         optimizer = optim.Adam(self.parameters(), lr=self.learning_rate, weight_decay=PREFERRED_WEIGHT_DECAY)
+
+         # Percent start for OneCycleLR
+         # Handles the case where max_epochs is less than 5
+         percent_start = 5 / int(self.trainer.max_epochs)
+         if percent_start >= 1:
+             percent_start = 0.3
+
+         # https://lightning.ai/docs/pytorch/stable/common/optimization.html#total-stepping-batches
+         scheduler_dict = {
+             "scheduler": OneCycleLR(
+                 optimizer=optimizer,
+                 max_lr=self.learning_rate,
+                 total_steps=int(self.trainer.estimated_stepping_batches),
+                 pct_start=percent_start,
+                 div_factor=100,
+                 three_phase=False,
+                 anneal_strategy="linear",
+                 final_div_factor=100,
+                 verbose=False,
+             ),
+             "interval": "step",
+         }
+
+         return {"optimizer": optimizer, "lr_scheduler": scheduler_dict}
+
+     # Define loss function
+     def compute_loss(self, prediction, target):
+         """Compute Loss"""
+
+         # Calculate loss
+         loss = self.loss_function(prediction, target)
+
+         return loss
+
+     # Define accuracy function
+     def compute_accuracy(self, prediction, target):
+         """Compute accuracy"""
+
+         # Calculate accuracy
+         acc = self.accuracy_function(prediction, target)
+
+         return acc * 100
+
+     # Function to compute loss and accuracy for both training and validation
+     def compute_metrics(self, batch):
+         """Function to calculate loss and accuracy"""
+
+         # Get data and target from batch
+         data, target = batch
+
+         # Generate predictions using the model
+         pred = self(data)
+
+         # Calculate loss for the batch
+         loss = self.compute_loss(prediction=pred, target=target)
+
+         # Calculate accuracy for the batch
+         acc = self.compute_accuracy(prediction=pred, target=target)
+
+         return loss, acc
+
+     # Get misclassified images based on how many images to return
+     def store_misclassified_images(self):
+         """Get an array of misclassified images"""
+
+         self.misclassified_image_data = {"images": [], "ground_truths": [], "predicted_vals": []}
+
+         # Put the model in evaluation mode
+         self.eval()
+
+         # Disable gradient calculation while testing
+         with torch.no_grad():
+             for batch in self.trainer.test_dataloaders:
+                 # Move data and labels to device
+                 data, target = batch
+                 data, target = data.to(self.device), target.to(self.device)
+
+                 # Predict using the model
+                 pred = self(data)
+
+                 # Get the index of the max log-probability
+                 output = pred.argmax(dim=1)
+
+                 # Save the incorrect predictions
+                 incorrect_indices = ~output.eq(target)
+
+                 # Store the incorrectly predicted images, generated predictions, and actual values
+                 self.misclassified_image_data["images"].extend(data[incorrect_indices])
+                 self.misclassified_image_data["ground_truths"].extend(target[incorrect_indices])
+                 self.misclassified_image_data["predicted_vals"].extend(output[incorrect_indices])
+
+     # Training function
+     def training_step(self, batch, batch_idx):
+         """Training step"""
+
+         # Compute loss and accuracy
+         loss, acc = self.compute_metrics(batch)
+
+         self.log("train_loss", loss, prog_bar=True, on_epoch=True, logger=True)
+         self.log("train_acc", acc, prog_bar=True, on_epoch=True, logger=True)
+         # Return training loss
+         return loss
+
+     # Validation function
+     def validation_step(self, batch, batch_idx):
+         """Validation step"""
+
+         # Compute loss and accuracy
+         loss, acc = self.compute_metrics(batch)
+
+         self.log("val_loss", loss, prog_bar=True, on_epoch=True, logger=True)
+         self.log("val_acc", acc, prog_bar=True, on_epoch=True, logger=True)
+         # Return validation loss
+         return loss
+
+     # The test function will just use the validation step
+     def test_step(self, batch, batch_idx):
+         """Test step"""
+
+         # Compute loss and accuracy
+         loss, acc = self.compute_metrics(batch)
+
+         self.log("test_loss", loss, prog_bar=False, on_epoch=True, logger=True)
+         self.log("test_acc", acc, prog_bar=False, on_epoch=True, logger=True)
+         # Return test loss
+         return loss
+
+     # At the end of a train epoch append the training loss and accuracy to an instance variable called results
+     def on_train_epoch_end(self):
+         """On train epoch end"""
+
+         # Append training loss and accuracy to results
+         self.results["train_loss"].append(self.trainer.callback_metrics["train_loss"].detach().item())
+         self.results["train_acc"].append(self.trainer.callback_metrics["train_acc"].detach().item())
+
+     # At the end of a validation epoch append the validation loss and accuracy to an instance variable called results
+     def on_validation_epoch_end(self):
+         """On validation epoch end"""
+
+         # Append validation loss and accuracy to results
+         # (stored under the "test_*" keys, which the plotting utilities read)
+         self.results["test_loss"].append(self.trainer.callback_metrics["val_loss"].detach().item())
+         self.results["test_acc"].append(self.trainer.callback_metrics["val_acc"].detach().item())
+
+     # # At the end of a test epoch append the test loss and accuracy to an instance variable called results
+     # def on_test_epoch_end(self):
+     #     """On test epoch end"""
+
+     #     # Append test loss and accuracy to results
+     #     self.results["test_loss"].append(self.trainer.callback_metrics["test_loss"].detach().item())
+     #     self.results["test_acc"].append(self.trainer.callback_metrics["test_acc"].detach().item())
+
+     # At the end of testing, save the misclassified images, predictions, and ground truths in misclassified_image_data
+     def on_test_end(self):
+         """On test end"""
+
+         print("Test ended! Saving misclassified images")
+         # Get misclassified images
+         self.store_misclassified_images()
modules/__pycache__/config.cpython-311.pyc ADDED
Binary file (1.22 kB).
 
modules/config.py ADDED
@@ -0,0 +1,54 @@
+ # Alert: Change these when running in production
+
+ # Constants naming convention: all caps separated by underscores
+ # https://realpython.com/python-constants/
+
+ # Where do we store the data?
+ DATA_PATH = "../../data/"
+ CHECKPOINT_PATH = "../../checkpoints/"
+ LOGGING_PATH = "../../logs/"
+ MISCLASSIFIED_PATH = "../../Misclassified_Data.pt"
+ MODEL_PATH = "../../CustomResNet.pt"
+
+ # Specify the number of epochs
+ NUM_EPOCHS = 24
+
+ # Set the batch size
+ BATCH_SIZE = 512
+
+ # Set seed value for reproducibility
+ SEED = 53
+
+ # What is the start LR and weight decay you'd prefer?
+ PREFERRED_START_LR = 5e-3
+ PREFERRED_WEIGHT_DECAY = 1e-5
+
+
+ # What is the mean and std deviation of the dataset?
+ CIFAR_MEAN = (0.4915, 0.4823, 0.4468)
+ CIFAR_STD = (0.2470, 0.2435, 0.2616)
+
+ # What is the cutout size?
+ CUTOUT_SIZE = 16
+
+ # What are the classes in CIFAR10?
+ # Create class labels and convert to a tuple
+ CIFAR_CLASSES = tuple(
+     c.capitalize()
+     for c in [
+         "plane",
+         "car",
+         "bird",
+         "cat",
+         "deer",
+         "dog",
+         "frog",
+         "horse",
+         "ship",
+         "truck",
+     ]
+ )
+
+
+ GRADIO_MISCLASSIFIED_PATH = "./assets/model/Misclassified_Data.pt"
+ GRADIO_MODEL_PATH = "./assets/model/CustomResNet.pt"
modules/dataset.py ADDED
@@ -0,0 +1,110 @@
+ """This file contains functions to download and transform the CIFAR10 dataset"""
+ # Needed for image transformations
+ import albumentations as A
+ import modules.config as config
+
+ # # Needed for padding issues in albumentations
+ # import cv2
+ import numpy as np
+ from albumentations.pytorch.transforms import ToTensorV2
+ from torch.utils.data import Dataset
+
+ # Use precomputed values for the mean and standard deviation of the dataset
+ CIFAR_MEAN = config.CIFAR_MEAN
+ CIFAR_STD = config.CIFAR_STD
+ CUTOUT_SIZE = config.CUTOUT_SIZE
+
+ # Create class labels and convert to a tuple
+ CIFAR_CLASSES = config.CIFAR_CLASSES
+
+
+ class CIFAR10Transforms(Dataset):
+     """Apply albumentations augmentations to the CIFAR10 dataset"""
+
+     # Given a dataset and transformations,
+     # apply the transformations and return the dataset
+     def __init__(self, dataset, transforms):
+         self.dataset = dataset
+         self.transforms = transforms
+
+     def __getitem__(self, idx):
+         # Get the image and label from the dataset
+         image, label = self.dataset[idx]
+
+         # Apply transformations on the image
+         image = self.transforms(image=np.array(image))["image"]
+
+         return image, label
+
+     def __len__(self):
+         return len(self.dataset)
+
+     def __repr__(self):
+         return f"CIFAR10Transforms(dataset={self.dataset}, transforms={self.transforms})"
+
+     def __str__(self):
+         return f"CIFAR10Transforms(dataset={self.dataset}, transforms={self.transforms})"
+
+
+ def apply_cifar_image_transformations(mean=CIFAR_MEAN, std=CIFAR_STD, cutout_size=CUTOUT_SIZE):
+     """
+     Function to apply the required transformations to the CIFAR10 dataset.
+     """
+     # Training data transformations
+     train_transforms = A.Compose(
+         [
+             # Normalize the images with the mean and standard deviation of the whole dataset
+             # https://albumentations.ai/docs/api_reference/augmentations/transforms/#albumentations.augmentations.transforms.Normalize
+             A.Normalize(mean=list(mean), std=list(std)),
+             # RandomCrop 32, 32 (after padding of 4)
+             # https://albumentations.ai/docs/api_reference/augmentations/geometric/transforms/#albumentations.augmentations.geometric.transforms.PadIfNeeded
+             # min_height and min_width are set to 36 to ensure the image is padded to 36x36 before cropping
+             # border_mode (OpenCV flag) specifies the pixel extrapolation method: one of cv2.BORDER_CONSTANT,
+             # cv2.BORDER_REPLICATE, cv2.BORDER_REFLECT, cv2.BORDER_WRAP, cv2.BORDER_REFLECT_101.
+             # Default: cv2.BORDER_REFLECT_101
+             A.PadIfNeeded(min_height=36, min_width=36),
+             # https://albumentations.ai/docs/api_reference/augmentations/crops/transforms/#albumentations.augmentations.crops.transforms.RandomCrop
+             A.RandomCrop(32, 32),
+             # CutOut via CoarseDropout (the maintained replacement for the deprecated A.Cutout);
+             # because the images are already normalized, a single zero-filled hole is dropped out
+             # https://albumentations.ai/docs/api_reference/augmentations/dropout/coarse_dropout/#coarsedropout-augmentation-augmentationsdropoutcoarse_dropout
+             A.CoarseDropout(
+                 max_holes=1,
+                 max_height=cutout_size,
+                 max_width=cutout_size,
+                 min_holes=1,
+                 min_height=cutout_size,
+                 min_width=cutout_size,
+                 p=1.0,
+             ),
+             # Convert the images to tensors
+             ToTensorV2(),
+         ]
+     )
+
+     # Test data transformations: normalize and convert to tensors only
+     test_transforms = A.Compose(
+         [
+             A.Normalize(mean=list(mean), std=list(std)),
+             # Convert the images to tensors
+             ToTensorV2(),
+         ]
+     )
+
+     return train_transforms, test_transforms
+
+
+ def calculate_mean_std(dataset):
+     """Function to calculate the mean and standard deviation of the CIFAR dataset"""
+     data = dataset.data.astype(np.float32) / 255.0
+     mean = np.mean(data, axis=(0, 1, 2))
+     std = np.std(data, axis=(0, 1, 2))
+     return mean, std
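+
+
+ # Usage sketch (not part of the training flow): recompute the statistics that the
+ # CIFAR_MEAN / CIFAR_STD constants in modules/config.py approximate.
+ if __name__ == "__main__":
+     import torchvision.datasets as datasets
+
+     train_set = datasets.CIFAR10("./data", train=True, download=True)
+     mean, std = calculate_mean_std(train_set)
+     print(f"mean={mean}, std={std}")  # roughly (0.491, 0.482, 0.447) and (0.247, 0.243, 0.262)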
modules/lightning_dataset.py ADDED
@@ -0,0 +1,109 @@
+ """This file contains functions to prepare the dataloaders in the way Lightning expects"""
+ import pytorch_lightning as pl
+ import torchvision.datasets as datasets
+ from lightning_fabric.utilities.seed import seed_everything
+ from modules.dataset import CIFAR10Transforms, apply_cifar_image_transformations
+ from torch.utils.data import DataLoader, random_split
+
+
+ class CIFARDataModule(pl.LightningDataModule):
+     """Lightning DataModule for the CIFAR10 dataset"""
+
+     def __init__(self, data_path, batch_size, seed, val_split=0, num_workers=0):
+         super().__init__()
+
+         self.data_path = data_path
+         self.batch_size = batch_size
+         self.seed = seed
+         self.val_split = val_split
+         self.num_workers = num_workers
+         self.dataloader_dict = {
+             # "shuffle": True,
+             "batch_size": self.batch_size,
+             "num_workers": self.num_workers,
+             "pin_memory": True,
+             # "worker_init_fn": self._init_fn,
+             "persistent_workers": self.num_workers > 0,
+         }
+         self.prepare_data_per_node = False
+
+         # Fixes the attribute-defined-outside-__init__ warning
+         self.training_dataset = None
+         self.validation_dataset = None
+         self.testing_dataset = None
+
+         # # Make sure data is downloaded
+         # self.prepare_data()
+
+     def _split_train_val(self, dataset):
+         """Split the dataset into train and validation sets"""
+
+         # Throw an error if the validation split is not between 0 and 1
+         if not 0 < self.val_split < 1:
+             raise ValueError("Validation split must be between 0 and 1")
+
+         # # Set seed again, might not be necessary
+         # seed_everything(int(self.seed))
+
+         # Calculate the lengths of each dataset
+         total_length = len(dataset)
+         train_length = int((1 - self.val_split) * total_length)
+         val_length = total_length - train_length
+
+         # Split the dataset
+         train_dataset, val_dataset = random_split(dataset, [train_length, val_length])
+
+         return train_dataset, val_dataset
+
+     # https://lightning.ai/docs/pytorch/stable/data/datamodule.html#prepare-data
+     def prepare_data(self):
+         # Download the CIFAR10 dataset if it doesn't exist
+         datasets.CIFAR10(self.data_path, train=True, download=True)
+         datasets.CIFAR10(self.data_path, train=False, download=True)
+
+     # https://lightning.ai/docs/pytorch/stable/data/datamodule.html#setup
+     # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.core.hooks.DataHooks.html#lightning.pytorch.core.hooks.DataHooks.setup
+     def setup(self, stage=None):
+         # seed_everything(int(self.seed))
+
+         # Define the data transformations
+         train_transforms, test_transforms = apply_cifar_image_transformations()
+         val_transforms = test_transforms
+
+         # Create train and validation datasets
+         if stage == "fit" or stage is None:
+             if self.val_split != 0:
+                 # Split the training data into training and validation sets
+                 data_train, data_val = self._split_train_val(datasets.CIFAR10(self.data_path, train=True))
+                 # Apply transformations
+                 self.training_dataset = CIFAR10Transforms(data_train, train_transforms)
+                 self.validation_dataset = CIFAR10Transforms(data_val, val_transforms)
+             else:
+                 # Only training data here
+                 self.training_dataset = CIFAR10Transforms(
+                     datasets.CIFAR10(self.data_path, train=True), train_transforms
+                 )
+                 # Validation will be the same as test
+                 self.validation_dataset = CIFAR10Transforms(
+                     datasets.CIFAR10(self.data_path, train=False), val_transforms
+                 )
+
+         # Create the test dataset
+         if stage == "test" or stage is None:
+             # Assign test split(s) for use in dataloaders
+             self.testing_dataset = CIFAR10Transforms(datasets.CIFAR10(self.data_path, train=False), test_transforms)
+
+     # https://lightning.ai/docs/pytorch/stable/data/datamodule.html#train-dataloader
+     def train_dataloader(self):
+         return DataLoader(self.training_dataset, **self.dataloader_dict, shuffle=True)
+
+     # https://lightning.ai/docs/pytorch/stable/data/datamodule.html#val-dataloader
+     def val_dataloader(self):
+         return DataLoader(self.validation_dataset, **self.dataloader_dict, shuffle=False)
+
+     # https://lightning.ai/docs/pytorch/stable/data/datamodule.html#test-dataloader
+     def test_dataloader(self):
+         return DataLoader(self.testing_dataset, **self.dataloader_dict, shuffle=False)
+
+     def _init_fn(self, worker_id):
+         """Seed each dataloader worker for reproducibility."""
+         seed_everything(int(self.seed) + worker_id)
modules/trainer.py ADDED
@@ -0,0 +1,121 @@
+ """Module to define the train and test functions."""
+
+ # from functools import partial
+
+ import modules.config as config
+ import pytorch_lightning as pl
+ import torch
+ from modules.utils import create_folder_if_not_exists
+ from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint, ModelSummary
+
+ # Import tuner
+ from pytorch_lightning.tuner.tuning import Tuner
+
+ # What is the start LR and weight decay you'd prefer?
+ PREFERRED_START_LR = config.PREFERRED_START_LR
+
+
+ def train_and_test_model(
+     batch_size,
+     num_epochs,
+     model,
+     datamodule,
+     logger,
+     debug=False,
+ ):
+     """Trains and tests the model by iterating through epochs using the Lightning Trainer."""
+
+     print(f"\n\nBatch size: {batch_size}, Total epochs: {num_epochs}\n\n")
+
+     print("Defining Lightning Callbacks")
+
+     # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html#modelcheckpoint
+     checkpoint = ModelCheckpoint(
+         dirpath=config.CHECKPOINT_PATH, monitor="val_acc", mode="max", filename="model_best_epoch", save_last=True
+     )
+     # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.LearningRateMonitor.html#learningratemonitor
+     lr_rate_monitor = LearningRateMonitor(logging_interval="epoch", log_momentum=False)
+     # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelSummary.html#lightning.pytorch.callbacks.ModelSummary
+     model_summary = ModelSummary(max_depth=0)
+
+     print("Defining Lightning Trainer")
+     # Change trainer settings for debugging
+     if debug:
+         num_epochs = 1
+         fast_dev_run = True
+         overfit_batches = 0.1
+         profiler = "advanced"
+     else:
+         fast_dev_run = False
+         overfit_batches = 0.0
+         profiler = None
+
+     # https://lightning.ai/docs/pytorch/stable/common/trainer.html#methods
+     trainer = pl.Trainer(
+         precision=16,
+         fast_dev_run=fast_dev_run,
+         # deterministic=True,
+         # devices="auto",
+         # accelerator="auto",
+         max_epochs=num_epochs,
+         logger=logger,
+         # enable_model_summary=False,
+         overfit_batches=overfit_batches,
+         log_every_n_steps=10,
+         # num_sanity_val_steps=5,
+         profiler=profiler,
+         # check_val_every_n_epoch=1,
+         callbacks=[checkpoint, lr_rate_monitor, model_summary],
+     )
+
+     # # Using the learning rate finder directly
+     # model.learning_rate = model.find_optimal_lr(train_loader=datamodule.train_dataloader())
+
+     # Using lr_find from the Tuner instead
+     # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.tuner.tuning.Tuner.html#lightning.pytorch.tuner.tuning.Tuner
+     # https://www.youtube.com/watch?v=cLZv0eZQSIE
+     print("Finding the optimal learning rate using Lightning Tuner.")
+     tuner = Tuner(trainer)
+     tuner.lr_find(
+         model=model,
+         datamodule=datamodule,
+         min_lr=PREFERRED_START_LR,
+         max_lr=5,
+         num_training=200,
+         mode="linear",
+         early_stop_threshold=10,
+         attr_name="learning_rate",
+     )
+
+     trainer.fit(model, datamodule=datamodule)
+     trainer.test(model, dataloaders=datamodule.test_dataloader())
+
+     # Obtain the results dictionary from the model
+     print("Collecting epoch level model results.")
+     results = model.results
+
+     # Get the list of misclassified images
+     print("Collecting misclassified images.")
+     misclassified_image_data = model.misclassified_image_data
+
+     # Save the model using torch.save as a backup
+     print("Saving the model.")
+     create_folder_if_not_exists(config.MODEL_PATH)
+     torch.save(model.state_dict(), config.MODEL_PATH)
+     print(f"Model saved to {config.MODEL_PATH}")
+
+     # Save the first few misclassified images to a file
+     num_elements = 20
+     print(f"Saving first {num_elements} misclassified images.")
+     subset_misclassified_image_data = {"images": [], "ground_truths": [], "predicted_vals": []}
+     subset_misclassified_image_data["images"] = misclassified_image_data["images"][:num_elements]
+     subset_misclassified_image_data["ground_truths"] = misclassified_image_data["ground_truths"][:num_elements]
+     subset_misclassified_image_data["predicted_vals"] = misclassified_image_data["predicted_vals"][:num_elements]
+     create_folder_if_not_exists(config.MISCLASSIFIED_PATH)
+     torch.save(subset_misclassified_image_data, config.MISCLASSIFIED_PATH)
+
+     return trainer, results, misclassified_image_data
modules/utils.py ADDED
@@ -0,0 +1,71 @@
+ """Module to define utility functions for the project."""
+ import os
+
+ import torch
+
+
+ def get_num_workers(model_run_location):
+     """Given a run location, return the number of workers to be used for data loading."""
+
+     # Calculate the number of workers from the available CPUs
+     num_workers = (os.cpu_count() - 1) if os.cpu_count() > 3 else 2
+
+     # Use the workers only on Colab; when running locally, fall back to 0
+     num_workers = num_workers if model_run_location == "colab" else 0
+
+     return num_workers
+
+
+ # Function to save the model
+ # https://debuggercafe.com/saving-and-loading-the-best-model-in-pytorch/
+ def save_model(epoch, model, optimizer, scheduler, batch_size, criterion, file_name):
+     """
+     Function to save the trained model along with other information to disk.
+     """
+     torch.save(
+         {
+             "epoch": epoch,
+             "model_state_dict": model.state_dict(),
+             "optimizer_state_dict": optimizer.state_dict(),
+             "scheduler_state_dict": scheduler.state_dict(),
+             "batch_size": batch_size,
+             "loss": criterion,
+         },
+         file_name,
+     )
+
+
+ # Given lists of train_losses, train_accuracies, test_losses,
+ # and test_accuracies, loop through the epochs and print the metrics
+ def pretty_print_metrics(num_epochs, results):
+     """
+     Function to print the metrics in a pretty format.
+     """
+     # Extract train_losses, train_acc, test_losses, test_acc from results
+     train_losses = results["train_loss"]
+     train_acc = results["train_acc"]
+     test_losses = results["test_loss"]
+     test_acc = results["test_acc"]
+
+     for i in range(num_epochs):
+         print(
+             f"Epoch: {i+1:02d}, Train Loss: {train_losses[i]:.4f}, "
+             f"Test Loss: {test_losses[i]:.4f}, Train Accuracy: {train_acc[i]:.4f}, "
+             f"Test Accuracy: {test_acc[i]:.4f}"
+         )
+
+
+ # Given a file path, extract the folder path and create the folder recursively if it does not already exist
+ def create_folder_if_not_exists(file_path):
+     """
+     Function to create a folder if it does not exist.
+     """
+     # Extract the folder path
+     folder_path = os.path.dirname(file_path)
+     print(f"Folder path: {folder_path}")
+
+     # Create the folder if it does not exist
+     if not os.path.exists(folder_path):
+         os.makedirs(folder_path, exist_ok=True)
+         print(f"Created folder: {folder_path}")
modules/visualize.py ADDED
@@ -0,0 +1,169 @@
+ import matplotlib.pyplot as plt
+ import numpy as np
+ from pytorch_grad_cam import GradCAM
+ from pytorch_grad_cam.utils.image import show_cam_on_image
+
+
+ def convert_back_image(image):
+     """Denormalize an image tensor using the dataset mean and standard deviation."""
+     cifar10_mean = (0.4914, 0.4822, 0.4471)
+     cifar10_std = (0.2469, 0.2433, 0.2615)
+     image = image.numpy().astype(dtype=np.float32)
+
+     for i in range(image.shape[0]):
+         image[i] = (image[i] * cifar10_std[i]) + cifar10_mean[i]
+
+     # Clip so that pixel values stay within bounds and do not trigger warnings
+     image = image.clip(0, 1)
+
+     return np.transpose(image, (1, 2, 0))
+
+
+ def plot_sample_training_images(batch_data, batch_label, class_label, num_images=30):
+     """Function to plot sample images from the training data."""
+     images, labels = batch_data, batch_label
+
+     # Calculate the number of images to plot
+     num_images = min(num_images, len(images))
+     # Calculate the number of rows and columns to plot
+     num_cols = 5
+     num_rows = int(np.ceil(num_images / num_cols))
+
+     # Initialize a subplot with the required number of rows and columns
+     fig, axs = plt.subplots(num_rows, num_cols, figsize=(10, 10))
+
+     # Iterate through the images and plot them in the grid along with class labels
+     for img_index in range(1, num_images + 1):
+         plt.subplot(num_rows, num_cols, img_index)
+         plt.tight_layout()
+         plt.axis("off")
+         plt.imshow(convert_back_image(images[img_index - 1]))
+         plt.title(class_label[labels[img_index - 1].item()])
+         plt.xticks([])
+         plt.yticks([])
+
+     return fig, axs
+
+
+ def plot_train_test_metrics(results):
+     """
+     Function to plot the training and test metrics.
+     """
+     # Extract train_losses, train_acc, test_losses, test_acc from results
+     train_losses = results["train_loss"]
+     train_acc = results["train_acc"]
+     test_losses = results["test_loss"]
+     test_acc = results["test_acc"]
+
+     # Plot the graphs in a 1x2 grid showing the training and test metrics
+     fig, axs = plt.subplots(1, 2, figsize=(16, 8))
+
+     # Loss plot
+     axs[0].plot(train_losses, label="Train")
+     axs[0].plot(test_losses, label="Test")
+     axs[0].set_title("Loss")
+     axs[0].legend(loc="upper right")
+
+     # Accuracy plot
+     axs[1].plot(train_acc, label="Train")
+     axs[1].plot(test_acc, label="Test")
+     axs[1].set_title("Accuracy")
+     axs[1].legend(loc="upper right")
+
+     return fig, axs
+
+
+ def plot_misclassified_images(data, class_label, num_images=10):
+     """Plot the misclassified images from the test dataset."""
+     # Calculate the number of images to plot
+     num_images = min(num_images, len(data["ground_truths"]))
+     # Calculate the number of rows and columns to plot
+     num_cols = 5
+     num_rows = int(np.ceil(num_images / num_cols))
+
+     # Initialize a subplot with the required number of rows and columns
+     fig, axs = plt.subplots(num_rows, num_cols, figsize=(num_cols * 2, num_rows * 2))
+
+     # Iterate through the images and plot them in the grid along with class labels
+     for img_index in range(1, num_images + 1):
+         # Get the ground truth and predicted labels for the image
+         label = data["ground_truths"][img_index - 1].cpu().item()
+         pred = data["predicted_vals"][img_index - 1].cpu().item()
+         # Get the image
+         image = data["images"][img_index - 1].cpu()
+         # Plot the image
+         plt.subplot(num_rows, num_cols, img_index)
+         plt.tight_layout()
+         plt.axis("off")
+         plt.imshow(convert_back_image(image))
+         plt.title(f"""ACT: {class_label[label]} \nPRED: {class_label[pred]}""")
+         plt.xticks([])
+         plt.yticks([])
+
+     return fig, axs
+
+
+ # Function to plot GradCAM for misclassified images using pytorch_grad_cam
+ def plot_gradcam_images(
+     model,
+     data,
+     class_label,
+     target_layers,
+     targets=None,
+     num_images=10,
+     image_weight=0.25,
+ ):
+     """Show GradCAM overlays for misclassified images"""
+
+     # Calculate the number of images to plot
+     num_images = min(num_images, len(data["ground_truths"]))
+     # Calculate the number of rows and columns to plot
+     num_cols = 5
+     num_rows = int(np.ceil(num_images / num_cols))
+
+     # Initialize a subplot with the required number of rows and columns
+     fig, axs = plt.subplots(num_rows, num_cols, figsize=(num_cols * 2, num_rows * 2))
+
+     # Initialize the GradCAM object
+     # https://github.com/jacobgil/pytorch-grad-cam/blob/master/pytorch_grad_cam/grad_cam.py
+     # https://github.com/jacobgil/pytorch-grad-cam/blob/master/pytorch_grad_cam/base_cam.py
+     cam = GradCAM(model=model, target_layers=target_layers)
+
+     # Iterate through the images and plot them in the grid along with class labels
+     for img_index in range(1, num_images + 1):
+         # Get the ground truth and predicted labels for the image
+         label = data["ground_truths"][img_index - 1].cpu().item()
+         pred = data["predicted_vals"][img_index - 1].cpu().item()
+         # Get the image
+         image = data["images"][img_index - 1].cpu()
+
+         # Get the GradCAM output
+         # https://github.com/jacobgil/pytorch-grad-cam/blob/master/pytorch_grad_cam/utils/model_targets.py
+         grad_cam_output = cam(
+             input_tensor=image.unsqueeze(0),
+             targets=targets,
+             aug_smooth=True,
+             eigen_smooth=True,
+         )
+         grad_cam_output = grad_cam_output[0, :]
+
+         # Overlay the CAM on top of the denormalized numpy image
+         overlayed_image = show_cam_on_image(
+             convert_back_image(image),
+             grad_cam_output,
+             use_rgb=True,
+             image_weight=image_weight,
+         )
+
+         # Plot the image
+         plt.subplot(num_rows, num_cols, img_index)
+         plt.tight_layout()
+         plt.axis("off")
+         plt.imshow(overlayed_image)
+         plt.title(f"""ACT: {class_label[label]} \nPRED: {class_label[pred]}""")
+         plt.xticks([])
+         plt.yticks([])
+     return fig, axs
requirements.txt ADDED
@@ -0,0 +1,11 @@
+ albumentations==1.3.1
+ grad-cam==1.5.0
+ gradio==3.39.0
+ numpy==1.25.0
+ pillow==9.4.0
+ pytorch-lightning==2.0.6
+ torch_lr_finder==0.2.1
+ torch==2.0.1
+ torchinfo==1.8.0
+ torchmetrics==0.11.4
+ torchvision==0.15.2