# Author: Soutrik — "new train model" (commit de7d21e)
from pathlib import Path
from typing import Union, Tuple, Optional, List
import os

import lightning as L
import torch
from loguru import logger
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torchvision.datasets.utils import download_and_extract_archive
class CatDogImageDataModule(L.LightningDataModule):
    """DataModule for Cat and Dog image classification using ImageFolder.

    Downloads and extracts the zipped dataset on demand, builds a
    reproducible train/validation split from the ``train`` folder, and a
    test set from the ``test`` folder.
    """

    # Seed for the train/val split so dataset membership is stable across
    # repeated setup() calls (prevents validation samples leaking into train).
    SPLIT_SEED: int = 42

    def __init__(
        self,
        root_dir: Union[str, Path] = "data",
        data_dir: Union[str, Path] = "cats_and_dogs_filtered",
        batch_size: int = 32,
        num_workers: int = 4,
        train_val_split: Tuple[float, float] = (0.8, 0.2),
        pin_memory: bool = False,
        image_size: int = 224,
        url: str = "https://download.pytorch.org/tutorials/cats_and_dogs_filtered.zip",
    ):
        """
        Args:
            root_dir: Directory under which the archive is downloaded/extracted.
            data_dir: Name of the extracted dataset folder inside ``root_dir``.
            batch_size: Samples per batch for every dataloader.
            num_workers: Worker processes per dataloader.
            train_val_split: Train/validation fractions (should sum to 1).
                NOTE: default is now an immutable tuple — the previous list
                default was a shared mutable default argument.
            pin_memory: Forwarded to ``DataLoader``.
            image_size: Square side length images are resized to.
            url: Download URL for the dataset archive.
        """
        super().__init__()
        self.root_dir = Path(root_dir)
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.train_val_split = train_val_split
        self.pin_memory = pin_memory
        self.image_size = image_size
        self.url = url
        # Resolved here (not in prepare_data) so the attribute always exists,
        # regardless of hook call order.
        self.dataset_path = self.root_dir / self.data_dir
        # Populated lazily by setup().
        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None

    def prepare_data(self):
        """Download and extract the dataset archive if it doesn't exist.

        Called once per node by Lightning; idempotent, so also safe to call
        directly when no Trainer is involved.
        """
        if not self.dataset_path.exists():
            logger.info("Downloading and extracting dataset.")
            download_and_extract_archive(
                url=self.url, download_root=self.root_dir, remove_finished=True
            )
            logger.info("Download completed.")

    def _build_transforms(self) -> Tuple[transforms.Compose, transforms.Compose]:
        """Return (train_transform, eval_transform) pipelines."""
        # ImageNet normalization statistics (standard for pretrained backbones).
        normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )
        train_transform = transforms.Compose(
            [
                transforms.Resize((self.image_size, self.image_size)),
                transforms.RandomHorizontalFlip(0.5),
                transforms.RandomRotation(5),  # small rotation for stability
                transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
                transforms.ToTensor(),
                normalize,
            ]
        )
        eval_transform = transforms.Compose(
            [
                transforms.Resize((self.image_size, self.image_size)),
                transforms.ToTensor(),
                normalize,
            ]
        )
        return train_transform, eval_transform

    def setup(self, stage: Optional[str] = None):
        """Set up the train, validation, and test datasets.

        Args:
            stage: ``"fit"``, ``"test"``, or ``None`` (build everything).
        """
        # Ensure the data is on disk even when setup() is called without a
        # Trainer (which would normally invoke prepare_data first).
        self.prepare_data()
        train_transform, eval_transform = self._build_transforms()

        if stage == "fit" or stage is None:
            full_train_dataset = ImageFolder(
                root=self.dataset_path / "train", transform=train_transform
            )
            self.class_names = full_train_dataset.classes
            train_size = int(self.train_val_split[0] * len(full_train_dataset))
            val_size = len(full_train_dataset) - train_size
            # Seeded generator: without it, each setup() call produced a
            # different split, leaking val images into training.
            self.train_dataset, self.val_dataset = random_split(
                full_train_dataset,
                [train_size, val_size],
                generator=torch.Generator().manual_seed(self.SPLIT_SEED),
            )
            logger.info(
                f"Train/Validation split: {len(self.train_dataset)} train, {len(self.val_dataset)} validation images."
            )

        if stage == "test" or stage is None:
            self.test_dataset = ImageFolder(
                root=self.dataset_path / "test", transform=eval_transform
            )
            # Make class names available even when only "test" was set up;
            # previously get_class_names() raised AttributeError in that case.
            if getattr(self, "class_names", None) is None:
                self.class_names = self.test_dataset.classes
            logger.info(f"Test dataset size: {len(self.test_dataset)} images.")

    def _create_dataloader(self, dataset, shuffle: bool = False) -> DataLoader:
        """Helper to build a DataLoader with the module's shared settings."""
        return DataLoader(
            dataset=dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            shuffle=shuffle,
        )

    def train_dataloader(self) -> DataLoader:
        """Shuffled loader over the training split."""
        return self._create_dataloader(self.train_dataset, shuffle=True)

    def val_dataloader(self) -> DataLoader:
        """Deterministic loader over the validation split."""
        return self._create_dataloader(self.val_dataset)

    def test_dataloader(self) -> DataLoader:
        """Deterministic loader over the test set."""
        return self._create_dataloader(self.test_dataset)

    def get_class_names(self) -> List[str]:
        """Class names discovered by ImageFolder (set during setup())."""
        return self.class_names
if __name__ == "__main__":
    # Manual smoke test: build the datamodule from the project's Hydra
    # config, exercise both stages, and report loader/class information.
    import hydra
    from omegaconf import DictConfig, OmegaConf
    import rootutils

    root = rootutils.setup_root(__file__, indicator=".project-root")

    @hydra.main(
        config_path=str(root / "configs"), version_base="1.3", config_name="train"
    )
    def test_datamodule(cfg: DictConfig):
        logger.info(f"Config:\n{OmegaConf.to_yaml(cfg)}")

        data_cfg = cfg.data
        datamodule = CatDogImageDataModule(
            root_dir=data_cfg.root_dir,
            data_dir=data_cfg.data_dir,
            batch_size=data_cfg.batch_size,
            num_workers=data_cfg.num_workers,
            train_val_split=data_cfg.train_val_split,
            pin_memory=data_cfg.pin_memory,
            image_size=data_cfg.image_size,
        )

        # "fit" stage: train + validation loaders.
        datamodule.setup(stage="fit")
        loaders = {
            "Train": datamodule.train_dataloader(),
            "Validation": datamodule.val_dataloader(),
        }
        # "test" stage: test loader and class names.
        datamodule.setup(stage="test")
        loaders["Test"] = datamodule.test_dataloader()

        logger.info(f"Train loader: {len(loaders['Train'])} batches")
        logger.info(f"Validation loader: {len(loaders['Validation'])} batches")
        logger.info(f"Test loader: {len(loaders['Test'])} batches")
        logger.info(f"Class names: {datamodule.get_class_names()}")

    test_datamodule()