{ "cells": [ { "cell_type": "markdown", "metadata": { "vscode": { "languageId": "plaintext" } }, "source": [ "In this notebook, we will be discussing about the pytorch lightning datamodule library with images in a folder strutcture with folders as class labels. We will be using the cats and dogs dataset from kaggle. The dataset can be downloaded from [here](https://www.kaggle.com/c/dogs-vs-cats/data). The dataset contains 25000 images of cats and dogs. We will be using 20000 images for training and 5000 images for validation. The images are in a folder structure with folders as class labels." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "data": { "application/javascript": "IPython.notebook.set_autosave_interval(300000)" }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Autosaving every 300 seconds\n" ] } ], "source": [ "%autosave 300\n", "%load_ext autoreload\n", "%autoreload 2\n", "%reload_ext autoreload\n", "%config Completer.use_jedi = False" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "/mnt/batch/tasks/shared/LS_root/mounts/clusters/soutrik-vm-dev/code/Users/Soutrik.Chowdhury/pytorch-template-aws\n" ] } ], "source": [ "import os\n", "\n", "os.chdir(\"..\")\n", "print(os.getcwd())" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/anaconda/envs/emlo_env/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] } ], "source": [ "from pathlib import Path\n", "from typing import Union, Tuple, Optional, List\n", "import os\n", "import lightning as L\n", "from torch.utils.data import DataLoader, random_split\n", "from torchvision import transforms\n", "from torchvision.datasets import ImageFolder\n", "from torchvision.datasets.utils import download_and_extract_archive\n", "from loguru import logger" ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [], "source": [ "class CatDogImageDataModule(L.LightningDataModule):\n", " \"\"\"DataModule for Cat and Dog Image Classification using ImageFolder.\"\"\"\n", "\n", " def __init__(\n", " self,\n", " data_root: Union[str, Path] = \"data\",\n", " data_dir: Union[str, Path] = \"cats_and_dogs_filtered\",\n", " batch_size: int = 32,\n", " num_workers: int = 4,\n", " train_val_split: List[float] = [0.8, 0.2],\n", " pin_memory: bool = False,\n", " image_size: int = 224,\n", " url: str = \"https://download.pytorch.org/tutorials/cats_and_dogs_filtered.zip\",\n", " ):\n", " super().__init__()\n", " self.data_root = Path(data_root)\n", " self.data_dir = data_dir\n", " self.batch_size = batch_size\n", " self.num_workers = num_workers\n", " self.train_val_split = train_val_split\n", " self.pin_memory = pin_memory\n", " self.image_size = image_size\n", " self.url = url\n", "\n", " # Initialize variables for datasets\n", " self.train_dataset = None\n", " self.val_dataset = None\n", " self.test_dataset = None\n", "\n", " def prepare_data(self):\n", " \"\"\"Download the dataset if it doesn't exist.\"\"\"\n", " self.dataset_path = self.data_root / self.data_dir\n", " if not self.dataset_path.exists():\n", " logger.info(\"Downloading and extracting dataset.\")\n", " download_and_extract_archive(\n", " url=self.url, download_root=self.data_root, remove_finished=True\n", " )\n", " logger.info(\"Download completed.\")\n", "\n", " def setup(self, stage: Optional[str] = None):\n", " \"\"\"Set up the train, validation, and test datasets.\"\"\"\n", "\n", " train_transform = transforms.Compose(\n", " [\n", " transforms.Resize((self.image_size, self.image_size)),\n", " transforms.RandomHorizontalFlip(0.1),\n", " transforms.RandomRotation(10),\n", " transforms.RandomAffine(0, shear=10, scale=(0.8, 1.2)),\n", " transforms.RandomAutocontrast(0.1),\n", " transforms.RandomAdjustSharpness(2, 0.1),\n", " transforms.ToTensor(),\n", " transforms.Normalize(\n", " mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]\n", " ),\n", " ]\n", " )\n", "\n", " test_transform = transforms.Compose(\n", " [\n", " transforms.Resize((self.image_size, self.image_size)),\n", " transforms.ToTensor(),\n", " transforms.Normalize(\n", " mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]\n", " ),\n", " ]\n", " )\n", "\n", " train_path = self.dataset_path / \"train\"\n", " test_path = self.dataset_path / \"test\"\n", "\n", " self.prepare_data()\n", "\n", " if stage == \"fit\" or stage is None:\n", " full_train_dataset = ImageFolder(root=train_path, transform=train_transform)\n", " self.class_names = full_train_dataset.classes\n", " train_size = int(self.train_val_split[0] * len(full_train_dataset))\n", " val_size = len(full_train_dataset) - train_size\n", " self.train_dataset, self.val_dataset = random_split(\n", " full_train_dataset, [train_size, val_size]\n", " )\n", " logger.info(\n", " f\"Train/Validation split: {len(self.train_dataset)} train, {len(self.val_dataset)} validation images.\"\n", " )\n", "\n", " if stage == \"test\" or stage is None:\n", " self.test_dataset = ImageFolder(root=test_path, transform=test_transform)\n", " logger.info(f\"Test dataset size: {len(self.test_dataset)} images.\")\n", "\n", " def _create_dataloader(self, dataset, shuffle: bool = False) -> DataLoader:\n", " \"\"\"Helper function to create a DataLoader.\"\"\"\n", " return DataLoader(\n", " dataset=dataset,\n", " batch_size=self.batch_size,\n", " num_workers=self.num_workers,\n", " pin_memory=self.pin_memory,\n", " shuffle=shuffle,\n", " )\n", "\n", " def train_dataloader(self) -> DataLoader:\n", " return self._create_dataloader(self.train_dataset, shuffle=True)\n", "\n", " def val_dataloader(self) -> DataLoader:\n", " return self._create_dataloader(self.val_dataset)\n", "\n", " def test_dataloader(self) -> DataLoader:\n", " return self._create_dataloader(self.test_dataset)\n", "\n", " def get_class_names(self) -> List[str]:\n", " return self.class_names" ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [], "source": [ "datamodule = CatDogImageDataModule(\n", " data_root=\"data\",\n", " data_dir=\"cats_and_dogs_filtered\",\n", " batch_size=32,\n", " num_workers=4,\n", " train_val_split=[0.8, 0.2],\n", " pin_memory=True,\n", " image_size=224,\n", " url=\"https://download.pytorch.org/tutorials/cats_and_dogs_filtered.zip\",\n", ")" ] }, { "cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "\u001b[32m2024-11-10 05:37:17.840\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36msetup\u001b[0m:\u001b[36m81\u001b[0m - \u001b[1mTrain/Validation split: 2241 train, 561 validation images.\u001b[0m\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\u001b[32m2024-11-10 05:37:17.910\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36msetup\u001b[0m:\u001b[36m87\u001b[0m - \u001b[1mTest dataset size: 198 images.\u001b[0m\n" ] } ], "source": [ "datamodule.prepare_data()\n", "datamodule.setup()\n", "class_names = datamodule.get_class_names()\n", "train_dataloader = datamodule.train_dataloader()\n", "val_dataloader= datamodule.val_dataloader()\n", "test_dataloader= datamodule.test_dataloader()" ] }, { "cell_type": "code", "execution_count": 36, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['cats', 'dogs']" ] }, "execution_count": 36, "metadata": {}, "output_type": "execute_result" } ], "source": [ "class_names" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "emlo_env", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.15" } }, "nbformat": 4, "nbformat_minor": 2 }