from collections.abc import Mapping
from typing import Any, Dict, Optional

import pytorch_lightning as pl
from datasets import load_dataset
from torch.utils.data import DataLoader


class MNISTDataModule(pl.LightningDataModule):
    """Lightning data module that serves the ``mnist`` dataset.

    The dataset is obtained through ``datasets.load_dataset('mnist')``.
    Two configuration entries are read:

    * ``transform_dataset`` -- callable handed to ``with_transform``.
    * ``batch_size`` -- batch size used by both dataloaders.
    """

    def __init__(self, config: Dict[str, Any]):
        """Store the configuration; no data is loaded here.

        Args:
            config: Either a mapping (e.g. a plain ``dict``) or an object
                exposing the options as attributes. Both styles are
                accepted, see ``_option``.
        """
        super().__init__()
        self.config = config

    def _option(self, name: str) -> Any:
        """Return the configuration entry called *name*.

        Attribute access is tried first so that every config object that
        worked before keeps behaving identically. If the attribute does
        not exist and the config is a mapping containing *name*, the
        value is read by key instead. This is what makes a plain ``dict``
        (the annotated type) usable.

        Raises:
            AttributeError: If *name* is neither an attribute nor a key.
        """
        cfg = self.config
        try:
            return getattr(cfg, name)
        except AttributeError:
            if isinstance(cfg, Mapping) and name in cfg:
                return cfg[name]
            # Not recoverable: surface the original attribute error.
            raise

    def setup(self, stage: Optional[str] = None) -> None:
        """Load the dataset and attach the configured transform.

        Args:
            stage: Accepted for compatibility with the Lightning hook
                signature; it is not consulted, so the same data is
                loaded for every stage.
        """
        dataset = load_dataset('mnist')
        self.dataset = dataset.with_transform(self._option('transform_dataset'))

    def train_dataloader(self) -> DataLoader:
        """Return a shuffling ``DataLoader`` over the ``'train'`` split."""
        return DataLoader(
            self.dataset['train'],
            batch_size=self._option('batch_size'),
            shuffle=True,
        )

    def val_dataloader(self) -> DataLoader:
        """Return a non-shuffling ``DataLoader`` over the ``'test'`` split."""
        return DataLoader(
            self.dataset['test'],  # Using test set as validation
            batch_size=self._option('batch_size'),
        )