bhimrazy commited on
Commit
23fa981
·
1 Parent(s): 60c474e

Add DRDataset and DRDataModule classes

Browse files
Files changed (2) hide show
  1. src/dataset.py +123 -0
  2. src/model.py +69 -0
src/dataset.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import lightning as L
4
+ import numpy as np
5
+ import pandas as pd
6
+ import torch
7
+ from sklearn.utils.class_weight import compute_class_weight
8
+ from torch.utils.data import DataLoader, Dataset
9
+ from torchvision.io import read_image
10
+ from torchvision.transforms import v2 as T
11
+
12
+
13
+ class DRDataset(Dataset):
14
+ def __init__(self, csv_path: str, transform=None):
15
+ self.csv_path = csv_path
16
+ self.transform = transform
17
+ self.image_paths, self.labels = self.load_csv_data()
18
+
19
+ def load_csv_data(self):
20
+ # Check if CSV file exists
21
+ if not os.path.isfile(self.csv_path):
22
+ raise FileNotFoundError(f"CSV file '{self.csv_path}' not found.")
23
+
24
+ # Load data from CSV file
25
+ data = pd.read_csv(self.csv_path)
26
+
27
+ # Check if 'image_path' and 'label' columns exist
28
+ if "image_path" not in data.columns or "label" not in data.columns:
29
+ raise ValueError("CSV file must contain 'image_path' and 'label' columns.")
30
+
31
+ # Extract image paths and labels
32
+ image_paths = data["image_path"].tolist()
33
+ labels = data["label"].tolist()
34
+
35
+ # Check if any image paths are invalid
36
+ invalid_image_paths = [
37
+ img_path for img_path in image_paths if not os.path.isfile(img_path)
38
+ ]
39
+ if invalid_image_paths:
40
+ raise FileNotFoundError(f"Invalid image paths found: {invalid_image_paths}")
41
+
42
+ # Convert labels to LongTensor
43
+ labels = torch.LongTensor(labels)
44
+
45
+ return image_paths, labels
46
+
47
+ def __len__(self):
48
+ return len(self.image_paths)
49
+
50
+ def __getitem__(self, idx):
51
+ image_path = self.image_paths[idx]
52
+ label = self.labels[idx]
53
+
54
+ # Load image
55
+ try:
56
+ image = read_image(image_path)
57
+ except Exception as e:
58
+ raise IOError(f"Error loading image at path '{image_path}': {e}")
59
+
60
+ # Apply transformations if provided
61
+ if self.transform:
62
+ try:
63
+ image = self.transform(image)
64
+ except Exception as e:
65
+ raise RuntimeError(
66
+ f"Error applying transformations to image at path '{image_path}': {e}"
67
+ )
68
+
69
+ return image, label
70
+
71
+
72
+ class DRDataModule(L.LightningDataModule):
73
+ def __init__(self, batch_size: int = 8, num_workers: int = 4):
74
+ super().__init__()
75
+ self.batch_size = batch_size
76
+ self.num_workers = num_workers
77
+
78
+ # Define the transformations
79
+ self.train_transform = T.Compose(
80
+ [
81
+ T.Resize((224, 224), antialias=True),
82
+ T.RandomHorizontalFlip(p=0.5),
83
+ T.ToDtype(torch.float32, scale=True),
84
+ T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
85
+ ]
86
+ )
87
+
88
+ self.val_transform = T.Compose(
89
+ [
90
+ T.Resize((224, 224), antialias=True),
91
+ T.ToDtype(torch.float32, scale=True),
92
+ T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
93
+ ]
94
+ )
95
+
96
+ self.num_classes = 5
97
+
98
+ def setup(self, stage=None):
99
+ self.train_dataset = DRDataset("data/train.csv", transform=self.train_transform)
100
+ self.val_dataset = DRDataset("data/val.csv", transform=self.val_transform)
101
+
102
+ # compute class weights
103
+ labels = self.train_dataset.labels.numpy()
104
+ self.class_weights = self.compute_class_weights(labels)
105
+
106
+ def train_dataloader(self):
107
+ return DataLoader(
108
+ self.train_dataset,
109
+ batch_size=self.batch_size,
110
+ shuffle=True,
111
+ num_workers=self.num_workers,
112
+ )
113
+
114
+ def val_dataloader(self):
115
+ return DataLoader(
116
+ self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers
117
+ )
118
+
119
+ def compute_class_weights(self, labels):
120
+ class_weights = compute_class_weight(
121
+ class_weight="balanced", classes=np.unique(labels), y=labels
122
+ )
123
+ return torch.tensor(class_weights, dtype=torch.float32)
src/model.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import lightning as L
2
+ import torch
3
+ from torch import nn
4
+ from torchmetrics.functional import accuracy
5
+ from torchvision import models
6
+
7
+
8
+ class DRModel(L.LightningModule):
9
+ def __init__(
10
+ self, num_classes: int, learning_rate: float = 2e-4, class_weights=None
11
+ ):
12
+ super().__init__()
13
+ self.save_hyperparameters()
14
+ self.num_classes = num_classes
15
+ self.learning_rate = learning_rate
16
+
17
+ # Define the model
18
+ # self.model = models.densenet121(weights=models.DenseNet121_Weights.DEFAULT)
19
+ self.model = models.densenet169(weights=models.DenseNet169_Weights.DEFAULT)
20
+ # self.model = models.vit_b_16(weights=models.ViT_B_16_Weights.DEFAULT)
21
+ # freeze the feature extractor
22
+ for param in self.model.parameters():
23
+ param.requires_grad = False
24
+ # Change the output layer to have the number of classes
25
+ in_features = self.model.classifier.in_features
26
+ # in_features = 768
27
+ self.model.classifier = nn.Sequential(
28
+ nn.Linear(in_features, in_features // 2),
29
+ nn.ReLU(),
30
+ nn.Dropout(0.1),
31
+ nn.Linear(in_features // 2, num_classes),
32
+ )
33
+
34
+ # Define the loss function
35
+ self.criterion = nn.CrossEntropyLoss(weight=class_weights)
36
+
37
+ def forward(self, x):
38
+ return self.model(x)
39
+
40
+ def training_step(self, batch):
41
+ x, y = batch
42
+ logits = self.model(x)
43
+ loss = self.criterion(logits, y)
44
+ self.log("train_loss", loss, prog_bar=True)
45
+ return loss
46
+
47
+ def validation_step(self, batch, batch_idx):
48
+ x, y = batch
49
+ logits = self.model(x)
50
+ loss = self.criterion(logits, y)
51
+ preds = torch.argmax(logits, dim=1)
52
+ acc = accuracy(preds, y, task="multiclass", num_classes=self.num_classes)
53
+ self.log("val_loss", loss, prog_bar=True)
54
+ self.log("val_acc", acc, prog_bar=True)
55
+
56
+ def configure_optimizers(self):
57
+ optimizer = torch.optim.Adam(
58
+ self.parameters(), lr=self.learning_rate, weight_decay=1e-4
59
+ )
60
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)
61
+ return {
62
+ "optimizer": optimizer,
63
+ "lr_scheduler": {
64
+ "scheduler": scheduler,
65
+ "interval": "epoch",
66
+ "monitor": "val_loss",
67
+ },
68
+ }
69
+ # return optimizer