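In this notebook, we discuss the PyTorch Lightning `LightningDataModule` for images stored in a folder structure where the folder names are the class labels. We will be using the cats and dogs dataset from Kaggle. The dataset can be downloaded from [here](https://www.kaggle.com/c/dogs-vs-cats/data). The full dataset contains 25,000 images of cats and dogs; with the full dataset we would use 20,000 images for training and 5,000 images for validation. The images are arranged in a folder structure with folders as class labels. Note: the code below downloads the smaller `cats_and_dogs_filtered` archive from download.pytorch.org, splits its `train` folder 80/20 into training and validation sets, and uses its `test` folder for testing.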

In [1]:
# Notebook housekeeping: autosave interval, autoreload extension, completer setting.
# Autosave the notebook every 300 seconds.
%autosave 300
# Load the autoreload extension and switch it to mode 2.
%load_ext autoreload
%autoreload 2
# NOTE(review): re-loading the extension right after %load_ext looks redundant — confirm it is needed.
%reload_ext autoreload
# Disable Jedi for the IPython tab-completer.
%config Completer.use_jedi = False

Autosaving every 300 seconds


In [2]:
import os

# Move the working directory to the project root (the parent of the directory
# the notebook was started in) so relative paths such as "data" resolve there.
#
# A bare os.chdir("..") climbs one level higher every time the cell is re-run.
# Remembering the starting directory once makes the cell idempotent: however
# many times it executes, the working directory ends up in the same place.
if "NOTEBOOK_START_DIR" not in globals():
    NOTEBOOK_START_DIR = os.getcwd()

os.chdir(os.path.dirname(NOTEBOOK_START_DIR))
print(os.getcwd())

/mnt/batch/tasks/shared/LS_root/mounts/clusters/soutrik-vm-dev/code/Users/Soutrik.Chowdhury/pytorch-template-aws


In [3]:
import os
from pathlib import Path
from typing import Union, Tuple, Optional, List

import lightning as L
import torch
from loguru import logger
from torch.utils.data import DataLoader, Subset, random_split
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torchvision.datasets.utils import download_and_extract_archive

 from .autonotebook import tqdm as notebook_tqdm


In [32]:
class CatDogImageDataModule(L.LightningDataModule):
    """DataModule for Cat and Dog Image Classification using ImageFolder.

    The dataset is expected at ``<data_root>/<data_dir>`` with ``train`` and
    ``test`` sub-folders, each holding one folder per class. If that directory
    is missing, ``prepare_data`` downloads and extracts the archive at ``url``
    into ``data_root``.

    Args:
        data_root: Directory that holds (or will hold) the extracted dataset.
        data_dir: Name of the dataset folder inside ``data_root``.
        batch_size: Batch size used by every dataloader.
        num_workers: Number of worker processes per dataloader.
        train_val_split: Fractions for the train/validation split; only the
            first entry (the training fraction) is used, the validation set
            receives the remainder.
        pin_memory: Forwarded to ``DataLoader``.
        image_size: Images are resized to ``image_size x image_size``.
        url: Archive to download when the dataset directory is missing.
        seed: Seed for the train/validation split so the split is reproducible
            and identical on every process. ``None`` falls back to the global
            torch RNG.
    """

    def __init__(
        self,
        data_root: Union[str, Path] = "data",
        data_dir: Union[str, Path] = "cats_and_dogs_filtered",
        batch_size: int = 32,
        num_workers: int = 4,
        train_val_split: List[float] = [0.8, 0.2],
        pin_memory: bool = False,
        image_size: int = 224,
        url: str = "https://download.pytorch.org/tutorials/cats_and_dogs_filtered.zip",
        seed: Optional[int] = 42,
    ):
        super().__init__()
        self.data_root = Path(data_root)
        self.data_dir = data_dir
        # Derived here rather than in prepare_data(): Lightning runs
        # prepare_data() on a single process only, so state assigned there is
        # not available to the other processes that still call setup().
        self.dataset_path = self.data_root / self.data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        # Copy so the shared mutable default list can never be altered through
        # an instance.
        self.train_val_split = list(train_val_split)
        self.pin_memory = pin_memory
        self.image_size = image_size
        self.url = url
        self.seed = seed

        # Populated by setup().
        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None
        self.class_names: List[str] = []

    def prepare_data(self) -> None:
        """Download and extract the dataset if its directory doesn't exist."""
        if not self.dataset_path.exists():
            logger.info("Downloading and extracting dataset.")
            download_and_extract_archive(
                url=self.url, download_root=self.data_root, remove_finished=True
            )
            logger.info("Download completed.")

    def _train_transform(self) -> transforms.Compose:
        """Resize, light augmentation, tensor conversion and normalisation."""
        return transforms.Compose(
            [
                transforms.Resize((self.image_size, self.image_size)),
                transforms.RandomHorizontalFlip(0.1),
                transforms.RandomRotation(10),
                transforms.RandomAffine(0, shear=10, scale=(0.8, 1.2)),
                transforms.RandomAutocontrast(0.1),
                transforms.RandomAdjustSharpness(2, 0.1),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )

    def _eval_transform(self) -> transforms.Compose:
        """Deterministic resize, tensor conversion and normalisation."""
        return transforms.Compose(
            [
                transforms.Resize((self.image_size, self.image_size)),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )

    def setup(self, stage: Optional[str] = None) -> None:
        """Set up the train, validation, and test datasets.

        ``prepare_data()`` must have been run beforehand (the Lightning
        ``Trainer`` does this automatically). It is deliberately not called
        from here, because ``setup`` runs on every process and concurrent
        downloads would race.

        Args:
            stage: ``"fit"`` / ``"validate"`` build the train and validation
                subsets, ``"test"`` builds the test set, ``None`` builds all.
        """
        train_path = self.dataset_path / "train"
        test_path = self.dataset_path / "test"

        if stage in ("fit", "validate", None):
            # Two views over the same files: random_split subsets share the
            # parent's transform, which would augment validation images too.
            # ImageFolder enumerates files in sorted order, so indices agree
            # between the two views.
            augmented_view = ImageFolder(
                root=train_path, transform=self._train_transform()
            )
            plain_view = ImageFolder(root=train_path, transform=self._eval_transform())
            self.class_names = augmented_view.classes

            total = len(augmented_view)
            train_size = int(self.train_val_split[0] * total)

            # A dedicated generator keeps the split reproducible and identical
            # across processes.
            generator = (
                None
                if self.seed is None
                else torch.Generator().manual_seed(self.seed)
            )
            shuffled = torch.randperm(total, generator=generator).tolist()

            self.train_dataset = Subset(augmented_view, shuffled[:train_size])
            self.val_dataset = Subset(plain_view, shuffled[train_size:])
            logger.info(
                f"Train/Validation split: {len(self.train_dataset)} train, {len(self.val_dataset)} validation images."
            )

        if stage in ("test", None):
            self.test_dataset = ImageFolder(
                root=test_path, transform=self._eval_transform()
            )
            if not self.class_names:
                # Keep get_class_names() usable when only the test stage ran.
                self.class_names = self.test_dataset.classes
            logger.info(f"Test dataset size: {len(self.test_dataset)} images.")

    def _create_dataloader(self, dataset, shuffle: bool = False) -> DataLoader:
        """Helper function to create a DataLoader.

        Raises:
            RuntimeError: If ``dataset`` is ``None``, i.e. ``setup()`` has not
                been run for the stage that owns it.
        """
        if dataset is None:
            raise RuntimeError(
                "Dataset is not initialised; call setup() for the matching stage first."
            )
        return DataLoader(
            dataset=dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            shuffle=shuffle,
        )

    def train_dataloader(self) -> DataLoader:
        """Shuffled dataloader over the training subset."""
        return self._create_dataloader(self.train_dataset, shuffle=True)

    def val_dataloader(self) -> DataLoader:
        """Unshuffled dataloader over the validation subset."""
        return self._create_dataloader(self.val_dataset)

    def test_dataloader(self) -> DataLoader:
        """Unshuffled dataloader over the test set."""
        return self._create_dataloader(self.test_dataset)

    def get_class_names(self) -> List[str]:
        """Return the class folder names found by setup() (empty before setup)."""
        return self.class_names

In [33]:
# Data-pipeline settings gathered in one mapping, then unpacked into the module.
datamodule_kwargs = {
    "data_root": "data",
    "data_dir": "cats_and_dogs_filtered",
    "batch_size": 32,
    "num_workers": 4,
    "train_val_split": [0.8, 0.2],
    "pin_memory": True,
    "image_size": 224,
    "url": "https://download.pytorch.org/tutorials/cats_and_dogs_filtered.zip",
}
datamodule = CatDogImageDataModule(**datamodule_kwargs)

In [35]:
# Fetch the data if needed, then build every split (stage=None sets up all of them).
datamodule.prepare_data()
datamodule.setup(stage=None)

class_names = datamodule.get_class_names()

# One loader per split, created in train -> val -> test order.
train_dataloader, val_dataloader, test_dataloader = (
    datamodule.train_dataloader(),
    datamodule.val_dataloader(),
    datamodule.test_dataloader(),
)

[32m2024-11-10 05:37:17.840[0m | [1mINFO [0m | [36m__main__[0m:[36msetup[0m:[36m81[0m - [1mTrain/Validation split: 2241 train, 561 validation images.[0m


[32m2024-11-10 05:37:17.910[0m | [1mINFO [0m | [36m__main__[0m:[36msetup[0m:[36m87[0m - [1mTest dataset size: 198 images.[0m


In [36]:
# Display the class labels returned by datamodule.get_class_names().
class_names

['cats', 'dogs']