Spaces:
Runtime error
Runtime error
from pathlib import Path | |
from typing import Union, Tuple, Optional, List | |
import os | |
import lightning as L | |
from torch.utils.data import DataLoader, random_split | |
from torchvision import transforms | |
from torchvision.datasets import ImageFolder | |
from torchvision.datasets.utils import download_and_extract_archive | |
from loguru import logger | |
class CatDogImageDataModule(L.LightningDataModule): | |
"""DataModule for Cat and Dog Image Classification using ImageFolder.""" | |
def __init__( | |
self, | |
root_dir: Union[str, Path] = "data", | |
data_dir: Union[str, Path] = "cats_and_dogs_filtered", | |
batch_size: int = 32, | |
num_workers: int = 4, | |
train_val_split: List[float] = [0.8, 0.2], | |
pin_memory: bool = False, | |
image_size: int = 224, | |
url: str = "https://download.pytorch.org/tutorials/cats_and_dogs_filtered.zip", | |
): | |
super().__init__() | |
self.root_dir = Path(root_dir) | |
self.data_dir = data_dir | |
self.batch_size = batch_size | |
self.num_workers = num_workers | |
self.train_val_split = train_val_split | |
self.pin_memory = pin_memory | |
self.image_size = image_size | |
self.url = url | |
# Initialize variables for datasets | |
self.train_dataset = None | |
self.val_dataset = None | |
self.test_dataset = None | |
def prepare_data(self): | |
"""Download the dataset if it doesn't exist.""" | |
self.dataset_path = self.root_dir / self.data_dir | |
if not self.dataset_path.exists(): | |
logger.info("Downloading and extracting dataset.") | |
download_and_extract_archive( | |
url=self.url, download_root=self.root_dir, remove_finished=True | |
) | |
logger.info("Download completed.") | |
def setup(self, stage: Optional[str] = None): | |
"""Set up the train, validation, and test datasets.""" | |
self.prepare_data() | |
train_transform = transforms.Compose( | |
[ | |
transforms.Resize((self.image_size, self.image_size)), | |
transforms.RandomHorizontalFlip(0.5), # Flip probability increased | |
transforms.RandomRotation(5), # Reduced rotation for stability | |
transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2), | |
transforms.ToTensor(), | |
transforms.Normalize( | |
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] | |
), | |
] | |
) | |
test_transform = transforms.Compose( | |
[ | |
transforms.Resize((self.image_size, self.image_size)), | |
transforms.ToTensor(), | |
transforms.Normalize( | |
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] | |
), | |
] | |
) | |
train_path = self.dataset_path / "train" | |
test_path = self.dataset_path / "test" | |
if stage == "fit" or stage is None: | |
full_train_dataset = ImageFolder(root=train_path, transform=train_transform) | |
self.class_names = full_train_dataset.classes | |
train_size = int(self.train_val_split[0] * len(full_train_dataset)) | |
val_size = len(full_train_dataset) - train_size | |
self.train_dataset, self.val_dataset = random_split( | |
full_train_dataset, [train_size, val_size] | |
) | |
logger.info( | |
f"Train/Validation split: {len(self.train_dataset)} train, {len(self.val_dataset)} validation images." | |
) | |
if stage == "test" or stage is None: | |
self.test_dataset = ImageFolder(root=test_path, transform=test_transform) | |
logger.info(f"Test dataset size: {len(self.test_dataset)} images.") | |
def _create_dataloader(self, dataset, shuffle: bool = False) -> DataLoader: | |
"""Helper function to create a DataLoader.""" | |
return DataLoader( | |
dataset=dataset, | |
batch_size=self.batch_size, | |
num_workers=self.num_workers, | |
pin_memory=self.pin_memory, | |
shuffle=shuffle, | |
) | |
def train_dataloader(self) -> DataLoader: | |
return self._create_dataloader(self.train_dataset, shuffle=True) | |
def val_dataloader(self) -> DataLoader: | |
return self._create_dataloader(self.val_dataset) | |
def test_dataloader(self) -> DataLoader: | |
return self._create_dataloader(self.test_dataset) | |
def get_class_names(self) -> List[str]: | |
return self.class_names | |
if __name__ == "__main__": | |
# Test the CatDogImageDataModule | |
import hydra | |
from omegaconf import DictConfig, OmegaConf | |
import rootutils | |
root = rootutils.setup_root(__file__, indicator=".project-root") | |
def test_datamodule(cfg: DictConfig): | |
logger.info(f"Config:\n{OmegaConf.to_yaml(cfg)}") | |
datamodule = CatDogImageDataModule( | |
root_dir=cfg.data.root_dir, | |
data_dir=cfg.data.data_dir, | |
batch_size=cfg.data.batch_size, | |
num_workers=cfg.data.num_workers, | |
train_val_split=cfg.data.train_val_split, | |
pin_memory=cfg.data.pin_memory, | |
image_size=cfg.data.image_size, | |
) | |
datamodule.setup(stage="fit") | |
train_loader = datamodule.train_dataloader() | |
val_loader = datamodule.val_dataloader() | |
datamodule.setup(stage="test") | |
test_loader = datamodule.test_dataloader() | |
class_names = datamodule.get_class_names() | |
logger.info(f"Train loader: {len(train_loader)} batches") | |
logger.info(f"Validation loader: {len(val_loader)} batches") | |
logger.info(f"Test loader: {len(test_loader)} batches") | |
logger.info(f"Class names: {class_names}") | |
test_datamodule() | |