Spaces:
Sleeping
Sleeping
from typing import Optional

from lightning import LightningDataModule
from torch.utils.data import Dataset, DataLoader
class Datamodule(LightningDataModule):
    """Serve a train dataset and an eval dataset as ``DataLoader`` objects.

    ``train_dataset`` backs the training loader only. The single
    ``eval_dataset`` backs the validation, prediction and test loaders.

    NOTE(review): no loader passes ``shuffle``, so the training loader uses
    the ``DataLoader`` default ordering -- confirm this is intended.
    """

    def __init__(
        self,
        train_dataset: Dataset,
        eval_dataset: Dataset,
        batch_train_size: int,
        num_workers: int,
        eval_batch_size: Optional[int] = None,
    ):
        """Store the datasets and loader settings.

        Args:
            train_dataset: Dataset wrapped by :meth:`train_dataloader`.
            eval_dataset: Dataset wrapped by :meth:`val_dataloader`,
                :meth:`predict_dataloader` and :meth:`test_dataloader`.
            batch_train_size: Batch size of the training loader.
            num_workers: ``num_workers`` value passed to every loader.
            eval_batch_size: Batch size of the val/predict/test loaders.
                Falls back to ``batch_train_size`` when ``None``.
        """
        super().__init__()
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.batch_train_size = batch_train_size
        self.eval_batch_size = (
            eval_batch_size if eval_batch_size is not None else batch_train_size
        )
        self.num_workers = num_workers

    def _make_loader(
        self,
        dataset: Dataset,
        batch_size: int,
        *,
        pin_and_persist: bool,
    ) -> DataLoader:
        """Build a loader for ``dataset`` using the shared ``num_workers``.

        Args:
            dataset: Dataset to wrap.
            batch_size: Batch size for the loader.
            pin_and_persist: When true, set ``pin_memory=True`` and keep
                workers alive between epochs if there are any workers.
                When false, both options keep their ``DataLoader`` defaults.

        Returns:
            The configured ``DataLoader``.
        """
        if not pin_and_persist:
            return DataLoader(
                dataset,
                batch_size=batch_size,
                num_workers=self.num_workers,
            )
        return DataLoader(
            dataset,
            batch_size=batch_size,
            num_workers=self.num_workers,
            pin_memory=True,
            # persistent_workers is only valid with at least one worker.
            persistent_workers=self.num_workers > 0,
        )

    def train_dataloader(self) -> DataLoader:
        """Load train set loader."""
        return self._make_loader(
            self.train_dataset,
            self.batch_train_size,
            pin_and_persist=True,
        )

    def val_dataloader(self) -> DataLoader:
        """Load val set loader."""
        return self._make_loader(
            self.eval_dataset,
            self.eval_batch_size,
            pin_and_persist=True,
        )

    def predict_dataloader(self) -> DataLoader:
        """Load predict set loader."""
        return self._make_loader(
            self.eval_dataset,
            self.eval_batch_size,
            pin_and_persist=False,
        )

    def test_dataloader(self) -> DataLoader:
        """Load test set loader."""
        return self._make_loader(
            self.eval_dataset,
            self.eval_batch_size,
            pin_and_persist=False,
        )