Shilpaj committed on
Commit 6a6474f · 1 Parent(s): fb36c8a

Upload resnet.py

Files changed (1)
  1. resnet.py +520 -0
resnet.py ADDED
#!/usr/bin/env python3
"""
PyTorch Lightning module for a ResNet-style architecture
Author: Shilpaj Bhalerao
"""
# Standard Library Imports
import os
import math

# Third-Party Imports
import numpy as np
import matplotlib.pyplot as plt
import albumentations as A

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, random_split

from torchvision import transforms
from torchvision.datasets import CIFAR10

from pytorch_lightning import LightningModule, Trainer
from torchmetrics import Accuracy

# Local Imports
from datasets import AlbumDataset
from utils import get_cifar_statistics
from visualize import visualize_cifar_augmentation, display_cifar_data_samples


class Layers:
    """
    Class containing builders for different types of convolutional blocks
    """

    def __init__(self, groups=1):
        """
        Constructor
        :param groups: Default number of groups for Group Normalization
        """
        self.groups = groups

    @staticmethod
    def standard_conv_layer(in_channels: int,
                            out_channels: int,
                            kernel_size: int = 3,
                            padding: int = 0,
                            stride: int = 1,
                            dilation: int = 1,
                            normalization: str = "batch",
                            last_layer: bool = False,
                            conv_type: str = "standard",
                            groups: int = 1):
        """
        Method to return a standard convolution block
        :param in_channels: Number of input channels
        :param out_channels: Number of output channels
        :param kernel_size: Size of the kernel used in the layer
        :param padding: Padding used in the layer
        :param stride: Stride used for convolution
        :param dilation: Dilation for Atrous convolution
        :param normalization: Type of normalization technique used
        :param last_layer: Flag to indicate if the layer is the last convolutional layer of the network
        :param conv_type: Type of convolutional layer
        :param groups: Number of groups for Group Normalization
        """
        # Select normalization type
        if normalization == "layer":
            _norm_layer = nn.GroupNorm(1, out_channels)
        elif normalization == "group":
            if not groups:
                raise ValueError("Number of groups is not defined")
            _norm_layer = nn.GroupNorm(groups, out_channels)
        else:
            _norm_layer = nn.BatchNorm2d(out_channels)

        # Select the convolution layer type
        if conv_type == "standard":
            conv_layer = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, stride=stride,
                                   kernel_size=kernel_size, bias=False, padding=padding)
        elif conv_type == "depthwise":
            conv_layer = Layers.depthwise_conv(in_channels=in_channels, out_channels=out_channels, stride=stride,
                                               padding=padding)
        elif conv_type == "dilated":
            conv_layer = Layers.dilated_conv(in_channels=in_channels, out_channels=out_channels, stride=stride,
                                             padding=padding, dilation=dilation)
        else:
            raise ValueError(f"Unknown conv_type: {conv_type!r}")

        # For the last layer, only return the convolution output
        if last_layer:
            return nn.Sequential(conv_layer)
        return nn.Sequential(
            conv_layer,
            _norm_layer,
            nn.ReLU(),
            # nn.Dropout(self.dropout_value)
        )

    @staticmethod
    def resnet_block(channels):
        """
        Method to create a ResNet block: two 3x3 conv-BN-ReLU layers
        :param channels: Number of input and output channels
        """
        return nn.Sequential(
            nn.Conv2d(in_channels=channels, out_channels=channels, stride=1, kernel_size=3, bias=False, padding=1),
            nn.BatchNorm2d(channels),
            nn.ReLU(),
            nn.Conv2d(in_channels=channels, out_channels=channels, stride=1, kernel_size=3, bias=False, padding=1),
            nn.BatchNorm2d(channels),
            nn.ReLU(),
        )

    @staticmethod
    def custom_block(input_channels, output_channels):
        """
        Method to create a custom configured block: conv -> maxpool -> BN -> ReLU
        :param input_channels: Number of input channels
        :param output_channels: Number of output channels
        """
        return nn.Sequential(
            nn.Conv2d(in_channels=input_channels, out_channels=output_channels, stride=1, kernel_size=3, bias=False,
                      padding=1),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.BatchNorm2d(output_channels),
            nn.ReLU(),
        )

    @staticmethod
    def depthwise_conv(in_channels, out_channels, stride=1, padding=0):
        """
        Method to return a depthwise separable convolution layer
        :param in_channels: Number of input channels
        :param out_channels: Number of output channels
        :param stride: Stride used for convolution
        :param padding: Padding used in the layer
        """
        return nn.Sequential(
            # Depthwise step: one 3x3 filter per input channel
            nn.Conv2d(in_channels=in_channels, out_channels=in_channels, stride=stride, groups=in_channels,
                      kernel_size=3, bias=False, padding=padding),
            # Pointwise step: 1x1 convolution mixes channels; stride stays at 1
            # so the requested stride is applied only once, in the depthwise step
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, stride=1, kernel_size=1, bias=False,
                      padding=0)
        )

    @staticmethod
    def dilated_conv(in_channels, out_channels, stride=1, padding=0, dilation=1):
        """
        Method to return a dilated convolution layer
        :param in_channels: Number of input channels
        :param out_channels: Number of output channels
        :param stride: Stride used for convolution
        :param padding: Padding used in the layer
        :param dilation: Dilation value for the kernel
        """
        return nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, stride=stride, kernel_size=3, bias=False,
                      padding=padding, dilation=dilation)
        )
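
# Illustrative note (added, not part of the upload): a call such as
# Layers.standard_conv_layer(in_channels=3, out_channels=64, padding=1)
# returns nn.Sequential(Conv2d(3, 64, 3, padding=1, bias=False),
# BatchNorm2d(64), ReLU()), so a (N, 3, 32, 32) batch keeps its 32x32
# spatial size while expanding to 64 channels.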


class LITResNet(LightningModule, Layers):
    """
    David's model architecture for the Session-10 CIFAR10 dataset
    """

    def __init__(self, class_names, data_dir='/data/'):
        """
        Constructor
        :param class_names: Names of the CIFAR10 classes
        :param data_dir: Directory in which the dataset is stored
        """
        # Initialize the LightningModule
        super().__init__()

        # Initialize variables
        self.classes = class_names
        self.data_dir = data_dir
        self.num_classes = 10
        self._learning_rate = 0.03
        # Inverse of Normalize(mean=0.50, std=0.23), used to undo
        # normalization before displaying images
        self.inv_normalize = transforms.Normalize(
            mean=[-0.50 / 0.23, -0.50 / 0.23, -0.50 / 0.23],
            std=[1 / 0.23, 1 / 0.23, 1 / 0.23]
        )
        self.batch_size = 512
        self.epochs = 24
        # Separate metric instances for training and validation so their
        # running states do not mix within an epoch
        self.train_accuracy = Accuracy(task='multiclass', num_classes=self.num_classes)
        self.val_accuracy = Accuracy(task='multiclass', num_classes=self.num_classes)
        self.train_transforms = transforms.Compose([transforms.ToTensor()])
        self.test_transforms = transforms.Compose([transforms.ToTensor()])
        self.stats_train = None
        self.stats_test = None
        self.cifar10_train = None
        self.cifar10_test = None
        self.cifar10_val = None
        self.misclassified_data = None

        # Layers defined for the model
        self.prep_layer = None
        self.custom_block1 = None
        self.custom_block2 = None
        self.custom_block3 = None
        self.resnet_block1 = None
        self.resnet_block3 = None
        self.pool4 = None
        self.fc = None
        self.dropout_value = None

        # Initialize all the layers
        self.model_layers()

    # ##################################################################################################
    # ################################ Model Architecture Related Hooks ################################
    # ##################################################################################################
    def model_layers(self):
        """
        Method to initialize layers for the model
        """
        # Prep Layer
        self.prep_layer = Layers.standard_conv_layer(in_channels=3, out_channels=64, kernel_size=3, padding=1,
                                                     stride=1)

        # Convolutional Block-1
        self.custom_block1 = Layers.custom_block(input_channels=64, output_channels=128)
        self.resnet_block1 = Layers.resnet_block(channels=128)

        # Convolutional Block-2
        self.custom_block2 = Layers.custom_block(input_channels=128, output_channels=256)

        # Convolutional Block-3
        self.custom_block3 = Layers.custom_block(input_channels=256, output_channels=512)
        self.resnet_block3 = Layers.resnet_block(channels=512)

        # MaxPool Layer
        self.pool4 = nn.MaxPool2d(kernel_size=4, stride=2)

        # Fully Connected Layer
        self.fc = nn.Linear(in_features=512, out_features=10, bias=False)

        # Dropout value of 10%
        self.dropout_value = 0.1

    def forward(self, x):
        """
        Forward pass for model training
        :param x: Input batch of images
        :return: Model predictions (log-probabilities)
        """
        # Prep Layer
        x = self.prep_layer(x)

        # Convolutional Block-1 with residual connection
        x = self.custom_block1(x)
        r1 = self.resnet_block1(x)
        x = x + r1

        # Convolutional Block-2
        x = self.custom_block2(x)

        # Convolutional Block-3 with residual connection
        x = self.custom_block3(x)
        r2 = self.resnet_block3(x)
        x = x + r2

        # MaxPool Layer
        x = self.pool4(x)

        # Fully Connected Layer
        x = x.view(-1, 512)
        x = self.fc(x)

        return F.log_softmax(x, dim=1)
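
    # Shape trace for a (N, 3, 32, 32) input, derived from the layer
    # definitions above (added note): prep -> 32x32x64, block1 -> 16x16x128,
    # block2 -> 8x8x256, block3 -> 4x4x512, pool4 -> 1x1x512,
    # fc -> 10 log-probabilities.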
    # ##################################################################################################
    # ############################## Training Configuration Related Hooks ##############################
    # ##################################################################################################

    def configure_optimizers(self):
        """
        Method to configure the optimizer and learning rate scheduler
        """
        # Use the learning-rate property so that externally set values take effect
        weight_decay = 1e-4
        optimizer = optim.Adam(self.parameters(), lr=self._learning_rate, weight_decay=weight_decay)

        # Scheduler: One Cycle Policy, stepped once per batch
        scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer,
                                                        max_lr=self._learning_rate,
                                                        steps_per_epoch=len(self.train_dataloader()),
                                                        epochs=self.epochs,
                                                        pct_start=5 / self.epochs,
                                                        div_factor=100,
                                                        three_phase=False,
                                                        final_div_factor=100,
                                                        anneal_strategy="linear"
                                                        )
        return [optimizer], [{'scheduler': scheduler, 'interval': 'step'}]

    @property
    def learning_rate(self) -> float:
        """
        Method to get the learning rate value
        """
        return self._learning_rate

    @learning_rate.setter
    def learning_rate(self, value: float):
        """
        Method to set the learning rate value
        :param value: Updated value of learning rate
        """
        self._learning_rate = value

    def set_training_config(self, *, epochs, batch_size):
        """
        Method to set parameters required for model training
        :param epochs: Number of epochs for which the model is to be trained
        :param batch_size: Batch size
        """
        self.epochs = epochs
        self.batch_size = batch_size
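
    # Schedule note (added, derived from the values above): with
    # interval='step' the scheduler advances every batch. The learning rate
    # warms up linearly from max_lr / div_factor to max_lr over the first
    # ~5 epochs (pct_start = 5 / epochs), then anneals linearly down to
    # max_lr / (div_factor * final_div_factor).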
    # #################################################################################################
    # ################################## Training Loop Related Hooks ##################################
    # #################################################################################################
    def training_step(self, train_batch, batch_index):
        """
        Method called on the training dataset to train the model
        :param train_batch: Batch containing images and labels
        :param batch_index: Index of the batch
        """
        x, y = train_batch
        logits = self.forward(x)
        # forward() returns log-probabilities, so nll_loss is the matching
        # criterion; the value is identical to cross_entropy on the raw scores
        # because log_softmax only shifts each row by a constant
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        self.train_accuracy(preds, y)

        self.log("train_loss", loss, prog_bar=True)
        self.log("train_acc", self.train_accuracy, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """
        Method called on the validation dataset to check whether the model is learning
        :param batch: Batch containing images and labels
        :param batch_idx: Index of the batch
        """
        x, y = batch
        logits = self.forward(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        self.val_accuracy(preds, y)

        # Calling self.log will surface these scalars in TensorBoard
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", self.val_accuracy, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        """
        Method called on the test dataset to check model performance on unseen data
        :param batch: Batch containing images and labels
        :param batch_idx: Index of the batch
        """
        # Reuse the validation step for testing
        return self.validation_step(batch, batch_idx)

    # ##############################################################################################
    # ##################################### Data Related Hooks #####################################
    # ##############################################################################################

    def set_transforms(self, train_set_transforms: dict, test_set_transforms: dict):
        """
        Method to set the transformations to be applied to the training and test datasets
        :param train_set_transforms: Dictionary of transformations for the training dataset
        :param test_set_transforms: Dictionary of transformations for the test dataset
        """
        # Only the dictionary values are used; the keys act as readable labels
        self.train_transforms = A.Compose(list(train_set_transforms.values()))
        self.test_transforms = A.Compose(list(test_set_transforms.values()))
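
    # Illustrative call (added; the transform choices are assumptions for
    # demonstration, not values from this file):
    #   model.set_transforms(
    #       train_set_transforms={"crop": A.RandomCrop(32, 32),
    #                             "norm": A.Normalize((0.5, 0.5, 0.5), (0.23, 0.23, 0.23))},
    #       test_set_transforms={"norm": A.Normalize((0.5, 0.5, 0.5), (0.23, 0.23, 0.23))},
    #   )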

    def prepare_data(self):
        """
        Method to download the dataset
        """
        self.stats_train = CIFAR10('./data', train=True, download=True, transform=transforms.ToTensor())
        self.stats_test = CIFAR10('./data', train=False, download=True, transform=transforms.ToTensor())

    def setup(self, stage=None):
        """
        Method to split the dataset into train, validation, and test sets
        """
        # Perform the split only if the dataset is not already split
        if not self.cifar10_train and not self.cifar10_test and not self.cifar10_val:

            # Assign train/val datasets for use in dataloaders
            if stage == "fit" or stage is None:
                cifar10_full = AlbumDataset(self.data_dir, train=True, download=True, transform=self.train_transforms)
                self.cifar10_train, self.cifar10_val = random_split(cifar10_full, [45_000, 5_000])

            # Assign test dataset for use in dataloader(s)
            if stage == "test" or stage is None:
                self.cifar10_test = AlbumDataset(self.data_dir, train=False, download=True,
                                                 transform=self.test_transforms)

    def train_dataloader(self):
        """
        Method to return the DataLoader for the training set
        """
        # Shuffle the training data every epoch
        return DataLoader(self.cifar10_train, batch_size=self.batch_size, shuffle=True, num_workers=os.cpu_count())

    def val_dataloader(self):
        """
        Method to return the DataLoader for the validation set
        """
        return DataLoader(self.cifar10_val, batch_size=self.batch_size, num_workers=os.cpu_count())

    def test_dataloader(self):
        """
        Method to return the DataLoader for the test set
        """
        return DataLoader(self.cifar10_test, batch_size=self.batch_size, num_workers=os.cpu_count())

    def get_statistics(self, data_set_type="Train"):
        """
        Method to print the statistics for the CIFAR10 dataset
        :param data_set_type: "Train" or "Test"
        """
        # Execute self.prepare_data() only if not done earlier
        if not self.stats_train and not self.stats_test:
            self.prepare_data()

        # Print stats for the selected dataset
        if data_set_type == "Train":
            get_cifar_statistics(self.stats_train)
        else:
            get_cifar_statistics(self.stats_test, data_set_type="Test")

    def display_data_samples(self, dataset="train", num_of_images=20):
        """
        Method to display data samples
        :param dataset: "train" or "test"
        :param num_of_images: Number of images to display
        """
        # Execute self.prepare_data() only if not done earlier
        if self.stats_train is None:
            self.prepare_data()

        if dataset == "train":
            display_cifar_data_samples(self.stats_train, num_of_images, self.classes)
        else:
            display_cifar_data_samples(self.stats_test, num_of_images, self.classes)

    @staticmethod
    def visualize_augmentation(aug_set_transforms: dict):
        """
        Method to visualize augmentations
        :param aug_set_transforms: Dictionary of transformations to be visualized
        """
        aug_train = AlbumDataset('./data', train=True, download=True)
        visualize_cifar_augmentation(aug_train, aug_set_transforms)

    # #############################################################################################
    # ############################# Misclassified Data Related Hooks #############################
    # #############################################################################################

    def get_misclassified_data(self):
        """
        Method to run the model on the test set and collect misclassified images
        """
        if self.misclassified_data:
            return self.misclassified_data

        self.misclassified_data = []
        self.prepare_data()
        self.setup()

        test_loader = self.test_dataloader()

        # Put the model in evaluation mode and disable gradient computation
        self.eval()
        with torch.no_grad():
            # Extract images and labels in a batch
            for data, target in test_loader:

                # Migrate the data to the device
                data, target = data.to(self.device), target.to(self.device)

                # Extract a single image and label from the batch
                for image, label in zip(data, target):

                    # Add a batch dimension to the image
                    image = image.unsqueeze(0)

                    # Get the model prediction for the image
                    output = self.forward(image)

                    # Convert the log-probabilities to a predicted class index
                    pred = output.argmax(dim=1, keepdim=True)

                    # If the prediction is incorrect, store the sample
                    if pred != label:
                        self.misclassified_data.append((image, label, pred))
        return self.misclassified_data

    def display_cifar_misclassified_data(self, number_of_samples: int = 10):
        """
        Method to plot misclassified images with labels
        :param number_of_samples: Number of images to plot
        """
        if not self.misclassified_data:
            self.misclassified_data = self.get_misclassified_data()

        fig = plt.figure(figsize=(10, 10))

        # Arrange the samples on a grid of 5 columns
        x_count = 5
        y_count = 1 if number_of_samples <= 5 else math.floor(number_of_samples / x_count)

        for i in range(number_of_samples):
            plt.subplot(y_count, x_count, i + 1)
            img = self.misclassified_data[i][0].squeeze().to('cpu')
            img = self.inv_normalize(img)
            plt.imshow(np.transpose(img, (1, 2, 0)))
            plt.title("Correct: " + self.classes[self.misclassified_data[i][1].item()] + '\n' +
                      'Output: ' + self.classes[self.misclassified_data[i][2].item()])
            plt.xticks([])
            plt.yticks([])
        plt.show()
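

# ---------------------------------------------------------------------------
# Minimal usage sketch (added for illustration). It assumes the local
# `datasets`, `utils`, and `visualize` modules are importable and that the
# class names, epoch count, and batch size below are reasonable placeholders,
# not values prescribed by this file.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    class_names = ('plane', 'car', 'bird', 'cat', 'deer',
                   'dog', 'frog', 'horse', 'ship', 'truck')
    model = LITResNet(class_names=class_names, data_dir='./data')
    model.set_training_config(epochs=24, batch_size=512)
    # Optionally: model.set_transforms(train_dict, test_dict) with
    # Albumentations transforms, as in the note after set_transforms above
    trainer = Trainer(max_epochs=model.epochs, accelerator="auto")
    trainer.fit(model)
    trainer.test(model)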