bhimrazy committed on
Commit
356b6f2
1 Parent(s): 945b303

Adds support for weighted random sampler

Browse files
Files changed (1) hide show
  1. src/dataset.py +26 -5
src/dataset.py CHANGED
@@ -5,7 +5,7 @@ import numpy as np
5
  import pandas as pd
6
  import torch
7
  from sklearn.utils.class_weight import compute_class_weight
8
- from torch.utils.data import DataLoader, Dataset
9
  from torchvision.io import read_image
10
  from torchvision.transforms import v2 as T
11
 
@@ -78,7 +78,9 @@ class DRDataModule(L.LightningDataModule):
78
  # Define the transformations
79
  self.train_transform = T.Compose(
80
  [
81
- T.Resize((224, 224), antialias=True),
 
 
82
  T.RandomHorizontalFlip(p=0.5),
83
  T.ToDtype(torch.float32, scale=True),
84
  T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
@@ -87,7 +89,7 @@ class DRDataModule(L.LightningDataModule):
87
 
88
  self.val_transform = T.Compose(
89
  [
90
- T.Resize((224, 224), antialias=True),
91
  T.ToDtype(torch.float32, scale=True),
92
  T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
93
  ]
@@ -101,13 +103,14 @@ class DRDataModule(L.LightningDataModule):
101
 
102
  # compute class weights
103
  labels = self.train_dataset.labels.numpy()
104
- self.class_weights = self.compute_class_weights(labels)
105
 
106
  def train_dataloader(self):
107
  return DataLoader(
108
  self.train_dataset,
109
  batch_size=self.batch_size,
110
- shuffle=True,
 
111
  num_workers=self.num_workers,
112
  )
113
 
@@ -121,3 +124,21 @@ class DRDataModule(L.LightningDataModule):
121
  class_weight="balanced", classes=np.unique(labels), y=labels
122
  )
123
  return torch.tensor(class_weights, dtype=torch.float32)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  import pandas as pd
6
  import torch
7
  from sklearn.utils.class_weight import compute_class_weight
8
+ from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler
9
  from torchvision.io import read_image
10
  from torchvision.transforms import v2 as T
11
 
 
78
  # Define the transformations
79
  self.train_transform = T.Compose(
80
  [
81
+ T.Resize((512, 512), antialias=True),
82
+ T.RandomAffine(degrees=10, translate=(0.01, 0.01), scale=(0.99, 1.01)),
83
+ T.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.01),
84
  T.RandomHorizontalFlip(p=0.5),
85
  T.ToDtype(torch.float32, scale=True),
86
  T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
 
89
 
90
  self.val_transform = T.Compose(
91
  [
92
+ T.Resize((512, 512), antialias=True),
93
  T.ToDtype(torch.float32, scale=True),
94
  T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
95
  ]
 
103
 
104
  # compute class weights
105
  labels = self.train_dataset.labels.numpy()
106
+ self.class_weights = None # self.compute_class_weights(labels)
107
 
108
  def train_dataloader(self):
109
  return DataLoader(
110
  self.train_dataset,
111
  batch_size=self.batch_size,
112
+ sampler=self._get_weighted_sampler(self.train_dataset.labels.numpy()),
113
+ # shuffle=True,
114
  num_workers=self.num_workers,
115
  )
116
 
 
124
  class_weight="balanced", classes=np.unique(labels), y=labels
125
  )
126
  return torch.tensor(class_weights, dtype=torch.float32)
127
+
128
def _get_weighted_sampler(self, labels: np.ndarray) -> WeightedRandomSampler:
    """Return a WeightedRandomSampler that draws each class with equal probability.

    Every sample is weighted by the inverse of its class frequency, so the
    weights tensor holds one weight per sample (not one per class).
    See https://discuss.pytorch.org/t/how-to-handle-imbalanced-classes/11264/2
    and https://www.maskaravivek.com/post/pytorch-weighted-random-sampler/

    Args:
        labels: 1-D array of class labels, one per training sample. Label
            values need not be contiguous or start at zero.

    Returns:
        A sampler that draws ``len(labels)`` indices with replacement.
    """
    labels = np.asarray(labels).ravel()

    # `class_index[i]` is the position of labels[i] within the sorted unique
    # classes, so it is always a valid index into the per-class arrays even
    # when a class id is missing from this split or labels do not start at 0.
    _, class_index, class_sample_count = np.unique(
        labels, return_inverse=True, return_counts=True
    )
    weight = 1.0 / class_sample_count
    samples_weight = torch.from_numpy(weight[class_index])

    return WeightedRandomSampler(
        weights=samples_weight, num_samples=len(labels), replacement=True
    )