bhimrazy's picture
updates the dataset file
159c02d
raw
history blame
No virus
5.37 kB
import os
import lightning as L
import numpy as np
import pandas as pd
import torch
from sklearn.utils.class_weight import compute_class_weight
from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler
from torchvision.io import read_image
from torchvision.transforms import v2 as T
class DRDataset(Dataset):
    """Image-classification dataset backed by a CSV of ``image_path`` / ``label`` rows.

    The CSV is read and validated eagerly in ``__init__``. Images are decoded lazily
    in ``__getitem__`` with ``torchvision.io.read_image``.

    Args:
        csv_path: Path to a CSV file containing ``image_path`` and ``label`` columns.
        transform: Optional callable applied to each decoded image tensor.

    Raises:
        FileNotFoundError: If the CSV file, or any image path listed in it, does not exist.
        ValueError: If the CSV lacks an ``image_path`` or ``label`` column.
    """

    def __init__(self, csv_path: str, transform=None):
        self.csv_path = csv_path
        self.transform = transform
        self.image_paths, self.labels = self.load_csv_data()

    def load_csv_data(self):
        """Read and validate ``self.csv_path``.

        Returns:
            A tuple ``(image_paths, labels)``: ``image_paths`` is a list of the CSV's
            ``image_path`` values, and ``labels`` is a 1-D ``torch.long`` tensor of its
            ``label`` values, in row order.

        Raises:
            FileNotFoundError: If the CSV or any listed image file is missing.
            ValueError: If a required column is missing.
        """
        if not os.path.isfile(self.csv_path):
            raise FileNotFoundError(f"CSV file '{self.csv_path}' not found.")

        data = pd.read_csv(self.csv_path)
        if "image_path" not in data.columns or "label" not in data.columns:
            raise ValueError("CSV file must contain 'image_path' and 'label' columns.")

        image_paths = data["image_path"].tolist()

        # Check every path up front so a bad row fails at construction time
        # instead of partway through an epoch.
        invalid_image_paths = [path for path in image_paths if not os.path.isfile(path)]
        if invalid_image_paths:
            raise FileNotFoundError(f"Invalid image paths found: {invalid_image_paths}")

        labels = torch.tensor(data["label"].tolist(), dtype=torch.long)
        return image_paths, labels

    def __len__(self):
        """Return the number of rows loaded from the CSV."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``, applying ``transform`` if set.

        Raises:
            OSError: If the image cannot be read/decoded; the underlying error is
                chained as ``__cause__``.
            RuntimeError: If ``transform`` fails; the underlying error is chained
                as ``__cause__``.
        """
        image_path = self.image_paths[idx]
        label = self.labels[idx]

        try:
            image = read_image(image_path)
        except Exception as e:
            # Re-raise with the offending path while keeping the original traceback.
            raise OSError(f"Error loading image at path '{image_path}': {e}") from e

        if self.transform:
            try:
                image = self.transform(image)
            except Exception as e:
                raise RuntimeError(
                    f"Error applying transformations to image at path '{image_path}': {e}"
                ) from e

        return image, label
class DRDataModule(L.LightningDataModule):
    """LightningDataModule that builds train/val ``DRDataset`` loaders from CSV files.

    The training loader draws samples through a ``WeightedRandomSampler``, so every
    class present in the training labels is drawn with equal expected frequency.

    Args:
        batch_size: Samples per batch for both loaders.
        num_workers: Number of ``DataLoader`` worker processes.
        train_csv_path: CSV (``image_path``, ``label``) for the training split.
        val_csv_path: CSV (``image_path``, ``label``) for the validation split.
        num_classes: Value exposed as ``self.num_classes``.
    """

    def __init__(
        self,
        batch_size: int = 8,
        num_workers: int = 4,
        train_csv_path: str = "data/train.csv",
        val_csv_path: str = "data/val.csv",
        num_classes: int = 5,
    ):
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.train_csv_path = train_csv_path
        self.val_csv_path = val_csv_path

        # Standard ImageNet channel mean/std; presumably chosen to match an
        # ImageNet-pretrained backbone (confirm against the model).
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]

        # Training: resize, then mild geometric/colour augmentation and a random
        # horizontal flip, then convert to float in [0, 1] and normalize.
        self.train_transform = T.Compose(
            [
                T.Resize((224, 224), antialias=True),
                T.RandomAffine(degrees=10, translate=(0.01, 0.01), scale=(0.99, 1.01)),
                T.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.01),
                T.RandomHorizontalFlip(p=0.5),
                T.ToDtype(torch.float32, scale=True),
                T.Normalize(mean=mean, std=std),
            ]
        )
        # Validation: deterministic resize + normalization only.
        self.val_transform = T.Compose(
            [
                T.Resize((224, 224), antialias=True),
                T.ToDtype(torch.float32, scale=True),
                T.Normalize(mean=mean, std=std),
            ]
        )
        self.num_classes = num_classes

    def setup(self, stage=None):
        """Create the train and validation datasets (``stage`` is accepted but ignored)."""
        self.train_dataset = DRDataset(self.train_csv_path, transform=self.train_transform)
        self.val_dataset = DRDataset(self.val_csv_path, transform=self.val_transform)
        # Class imbalance is currently handled by the weighted sampler in
        # train_dataloader, so class_weights is left as None. Use
        # self.compute_class_weights(self.train_dataset.labels.numpy()) if
        # per-class weights are needed instead.
        self.class_weights = None

    def train_dataloader(self):
        """Return the training loader, which rebalances classes by weighted sampling."""
        sampler = self._get_weighted_sampler(self.train_dataset.labels.numpy())
        # DataLoader forbids combining `sampler` with `shuffle=True`; the weighted
        # sampler already randomizes the order.
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            sampler=sampler,
            num_workers=self.num_workers,
        )

    def val_dataloader(self):
        """Return the validation loader (sequential order, no sampler)."""
        return DataLoader(
            self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers
        )

    def compute_class_weights(self, labels):
        """Return sklearn "balanced" class weights for ``labels`` as a float32 tensor.

        Entries are ordered like ``np.unique(labels)``.
        """
        class_weights = compute_class_weight(
            class_weight="balanced", classes=np.unique(labels), y=labels
        )
        return torch.tensor(class_weights, dtype=torch.float32)

    def _get_weighted_sampler(self, labels: np.ndarray) -> WeightedRandomSampler:
        """Return a WeightedRandomSampler that balances the classes in ``labels``.

        Each sample is weighted by ``1 / count(its class)``, so every class that
        occurs in ``labels`` is drawn with equal expected frequency. The weights
        are per sample, not per class, as WeightedRandomSampler requires. See
        https://discuss.pytorch.org/t/how-to-handle-imbalanced-classes/11264/2

        Class counts are looked up through the inverse index from ``np.unique``
        rather than by using the label value as an array index. This keeps the
        weights correct when labels are not contiguous ``0..K-1`` (e.g. a class
        absent from the split).
        """
        labels = np.asarray(labels)
        _, inverse, counts = np.unique(labels, return_inverse=True, return_counts=True)
        samples_weight = torch.from_numpy(1.0 / counts[inverse.reshape(-1)])
        # Draw len(labels) samples per epoch, with replacement so minority
        # classes can be repeated.
        return WeightedRandomSampler(
            weights=samples_weight, num_samples=len(labels), replacement=True
        )