bhimrazy committed on
Commit
c118196
1 Parent(s): e5d6e03

Refactors dataset and datamodules

Browse files
Files changed (2) hide show
  1. src/data_module.py +117 -0
  2. src/dataset.py +1 -79
src/data_module.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import lightning as L
2
+ import numpy as np
3
+ import torch
4
+ from sklearn.utils.class_weight import compute_class_weight
5
+ from torch.utils.data import DataLoader, WeightedRandomSampler
6
+ from torchvision.transforms import v2 as T
7
+
8
+ from src.dataset import DRDataset
9
+
10
+
11
class DRDataModule(L.LightningDataModule):
    """Lightning data module that builds the train/validation loaders for ``DRDataset``.

    Class imbalance can optionally be handled in one of two mutually exclusive
    ways: by exposing per-class weights (``self.class_weights``) or by drawing
    training samples through a ``WeightedRandomSampler``.

    Args:
        train_csv_path: Path passed to ``DRDataset`` for the training split.
        val_csv_path: Path passed to ``DRDataset`` for the validation split.
        image_size: Side length images are resized to (square).
        batch_size: Batch size used by both dataloaders.
        num_workers: Number of worker processes used by both dataloaders.
        use_class_weighting: If True, ``setup`` computes balanced class weights
            and stores them in ``self.class_weights``; otherwise it is ``None``.
        use_weighted_sampler: If True, the training loader samples with
            replacement using inverse-class-frequency weights.

    Raises:
        ValueError: If both ``use_class_weighting`` and ``use_weighted_sampler``
            are True.
    """

    def __init__(
        self,
        train_csv_path,
        val_csv_path,
        image_size: int = 224,
        batch_size: int = 8,
        num_workers: int = 4,
        use_class_weighting: bool = False,
        use_weighted_sampler: bool = False,
    ):
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers

        # Ensure mutual exclusivity between use_class_weighting and use_weighted_sampler
        if use_class_weighting and use_weighted_sampler:
            raise ValueError(
                "use_class_weighting and use_weighted_sampler cannot both be True"
            )

        self.train_csv_path = train_csv_path
        self.val_csv_path = val_csv_path
        self.use_class_weighting = use_class_weighting
        self.use_weighted_sampler = use_weighted_sampler

        # Training pipeline: resize, light geometric/color augmentation, then
        # convert to float and normalize.
        self.train_transform = T.Compose(
            [
                T.Resize((image_size, image_size), antialias=True),
                T.RandomAffine(degrees=10, translate=(0.01, 0.01), scale=(0.99, 1.01)),
                T.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.01),
                T.RandomHorizontalFlip(p=0.5),
                T.ToDtype(torch.float32, scale=True),
                T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ]
        )

        # Validation pipeline: same resize/normalization, no augmentation.
        self.val_transform = T.Compose(
            [
                T.Resize((image_size, image_size), antialias=True),
                T.ToDtype(torch.float32, scale=True),
                T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ]
        )

    def setup(self, stage=None):
        """Set up datasets for training and validation.

        Populates ``train_dataset``, ``val_dataset``, ``num_classes`` and
        ``class_weights``. ``stage`` is accepted for the Lightning hook
        signature but is not used.
        """
        # Initialize datasets with specified transformations
        self.train_dataset = DRDataset(
            self.train_csv_path, transform=self.train_transform
        )
        self.val_dataset = DRDataset(self.val_csv_path, transform=self.val_transform)

        # NOTE(review): assumes DRDataset.labels is a tensor-like object exposing
        # .numpy() -- confirm against src/dataset.py.
        labels = self.train_dataset.labels.numpy()

        # num_classes is the number of distinct labels present in the training
        # split, so a class missing from the training CSV is not counted.
        self.num_classes = len(np.unique(labels))
        self.class_weights = (
            self._compute_class_weights(labels) if self.use_class_weighting else None
        )

    def train_dataloader(self):
        """Returns a DataLoader for training data."""
        if self.use_weighted_sampler:
            sampler = self._get_weighted_sampler(self.train_dataset.labels.numpy())
            shuffle = False  # Sampler will handle shuffling
        else:
            sampler = None
            shuffle = True

        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            sampler=sampler,
            shuffle=shuffle,
            num_workers=self.num_workers,
        )

    def val_dataloader(self):
        """Returns a DataLoader for validation data (default sequential order)."""
        return DataLoader(
            self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers
        )

    def _compute_class_weights(self, labels):
        """Returns "balanced" class weights as a float32 tensor.

        The tensor has one entry per distinct value in ``labels``, ordered as
        ``np.unique(labels)``.
        """
        class_weights = compute_class_weight(
            class_weight="balanced", classes=np.unique(labels), y=labels
        )
        return torch.tensor(class_weights, dtype=torch.float32)

    def _get_weighted_sampler(self, labels: np.ndarray) -> WeightedRandomSampler:
        """Returns a WeightedRandomSampler based on class weights.

        The weights tensor should contain a weight for each sample, not the class weights.
        Have a look at this post for an example: https://discuss.pytorch.org/t/how-to-handle-imbalanced-classes/11264/2
        https://www.maskaravivek.com/post/pytorch-weighted-random-sampler/

        Each sample is weighted by the inverse of its class frequency, and
        ``len(labels)`` samples are drawn with replacement.
        """
        # `inverse` maps every sample to the position of its class in the sorted
        # unique-label array. Indexing the per-class weights by that position
        # (rather than by the raw label value) stays correct when label ids are
        # not exactly 0..K-1, e.g. when a class is absent from the training CSV.
        _, inverse, class_sample_count = np.unique(
            labels, return_inverse=True, return_counts=True
        )
        weight = 1.0 / class_sample_count
        samples_weight = torch.from_numpy(weight[inverse.reshape(-1)])

        return WeightedRandomSampler(
            weights=samples_weight, num_samples=len(labels), replacement=True
        )
src/dataset.py CHANGED
@@ -1,13 +1,9 @@
1
  import os
2
 
3
- import lightning as L
4
- import numpy as np
5
  import pandas as pd
6
  import torch
7
- from sklearn.utils.class_weight import compute_class_weight
8
- from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler
9
  from torchvision.io import read_image
10
- from torchvision.transforms import v2 as T
11
 
12
 
13
  class DRDataset(Dataset):
@@ -68,77 +64,3 @@ class DRDataset(Dataset):
68
 
69
  return image, label
70
 
71
-
72
- class DRDataModule(L.LightningDataModule):
73
- def __init__(self, batch_size: int = 8, num_workers: int = 4):
74
- super().__init__()
75
- self.batch_size = batch_size
76
- self.num_workers = num_workers
77
-
78
- # Define the transformations
79
- self.train_transform = T.Compose(
80
- [
81
- T.Resize((224, 224), antialias=True),
82
- T.RandomAffine(degrees=10, translate=(0.01, 0.01), scale=(0.99, 1.01)),
83
- T.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.01),
84
- T.RandomHorizontalFlip(p=0.5),
85
- T.ToDtype(torch.float32, scale=True),
86
- T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
87
- ]
88
- )
89
-
90
- self.val_transform = T.Compose(
91
- [
92
- T.Resize((224, 224), antialias=True),
93
- T.ToDtype(torch.float32, scale=True),
94
- T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
95
- ]
96
- )
97
-
98
- self.num_classes = 5
99
-
100
- def setup(self, stage=None):
101
- self.train_dataset = DRDataset("data/train.csv", transform=self.train_transform)
102
- self.val_dataset = DRDataset("data/val.csv", transform=self.val_transform)
103
-
104
- # compute class weights
105
- labels = self.train_dataset.labels.numpy()
106
- self.class_weights = None # self.compute_class_weights(labels)
107
-
108
- def train_dataloader(self):
109
- return DataLoader(
110
- self.train_dataset,
111
- batch_size=self.batch_size,
112
- sampler=self._get_weighted_sampler(self.train_dataset.labels.numpy()),
113
- # shuffle=True,
114
- num_workers=self.num_workers,
115
- )
116
-
117
- def val_dataloader(self):
118
- return DataLoader(
119
- self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers
120
- )
121
-
122
- def compute_class_weights(self, labels):
123
- class_weights = compute_class_weight(
124
- class_weight="balanced", classes=np.unique(labels), y=labels
125
- )
126
- return torch.tensor(class_weights, dtype=torch.float32)
127
-
128
- def _get_weighted_sampler(self, labels: np.ndarray) -> WeightedRandomSampler:
129
- """Returns a WeightedRandomSampler based on class weights.
130
-
131
- The weights tensor should contain a weight for each sample, not the class weights.
132
- Have a look at this post for an example: https://discuss.pytorch.org/t/how-to-handle-imbalanced-classes/11264/2
133
- https://www.maskaravivek.com/post/pytorch-weighted-random-sampler/
134
- """
135
-
136
-
137
- class_sample_count = np.array([len(np.where(labels == label)[0]) for label in np.unique(labels)])
138
- weight = 1. / class_sample_count
139
- samples_weight = np.array([weight[label] for label in labels])
140
- samples_weight = torch.from_numpy(samples_weight)
141
-
142
- # class_weights = compute_class_weight("balanced", classes=np.unique(labels), y=labels)
143
- # class_weights_tensor = torch.tensor(class_weights, dtype=torch.float32)
144
- return WeightedRandomSampler(weights=samples_weight, num_samples=len(labels), replacement=True)
 
1
  import os
2
 
 
 
3
  import pandas as pd
4
  import torch
5
+ from torch.utils.data import Dataset
 
6
  from torchvision.io import read_image
 
7
 
8
 
9
  class DRDataset(Dataset):
 
64
 
65
  return image, label
66