from functools import partial
import os
import sys

import numpy as np
import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader, Dataset

# NOTE(review): these run at import time and change the process-wide working
# directory and module search path, so any relative path used after this module
# is imported resolves against sys.path[0]. Kept unchanged because the project
# imports below (`lvdm`, `utils`) appear to rely on the parent directory being
# importable.
os.chdir(sys.path[0])
sys.path.append("..")

from lvdm.data.base import Txt2ImgIterableBaseDataset
from utils.utils import instantiate_from_config


def worker_init_fn(_):
    """Per-worker initialisation hook for ``DataLoader(worker_init_fn=...)``.

    Runs inside each DataLoader worker process.

    * For ``Txt2ImgIterableBaseDataset`` instances, each worker gets its own
      contiguous shard of ``dataset.valid_ids`` (stored as
      ``dataset.sample_ids``), so workers do not yield the same records.
    * For every dataset, NumPy's global RNG is reseeded so that workers draw
      different random numbers.

    The positional argument (the worker id supplied by ``DataLoader``) is
    ignored; the id is read from ``torch.utils.data.get_worker_info()``.
    """
    worker_info = torch.utils.data.get_worker_info()
    dataset = worker_info.dataset
    worker_id = worker_info.id

    if isinstance(dataset, Txt2ImgIterableBaseDataset):
        # Equal-size contiguous shard per worker. The trailing
        # ``num_records % num_workers`` ids are not assigned to any worker.
        split_size = dataset.num_records // worker_info.num_workers
        start = worker_id * split_size
        dataset.sample_ids = dataset.valid_ids[start:start + split_size]

    # Inside a worker, torch.initial_seed() is base_seed + worker_id, where
    # base_seed is drawn from the main process's torch RNG each time the
    # workers are started. Deriving the NumPy seed from it therefore gives
    # distinct streams per worker *and* per epoch. Seeding from the NumPy state
    # inherited from the parent would repeat the same streams every epoch. The
    # modulo keeps the value inside np.random.seed's accepted range [0, 2**32).
    np.random.seed(torch.initial_seed() % 2**32)


class WrappedDataset(Dataset):
    """Wraps an arbitrary object with __len__ and __getitem__ into a pytorch dataset"""

    def __init__(self, dataset):
        # The wrapped object; only its __len__ and __getitem__ are used.
        self.data = dataset

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


class DataModuleFromConfig(pl.LightningDataModule):
    """LightningDataModule whose datasets are built from config objects.

    Each non-``None`` split config (``train``, ``validation``, ``test``,
    ``predict``) is instantiated with ``instantiate_from_config`` in
    :meth:`setup`. The matching ``*_dataloader`` attribute is bound only for
    the splits that were configured.

    Args:
        batch_size: Batch size used by every DataLoader.
        train, validation, test, predict: Per-split dataset configs, passed to
            ``instantiate_from_config``.
        wrap: If True, wrap every instantiated dataset in ``WrappedDataset``.
        num_workers: DataLoader worker count; defaults to ``batch_size * 2``.
        shuffle_test_loader: Shuffle the test loader (never applied to
            ``Txt2ImgIterableBaseDataset`` instances).
        use_worker_init_fn: Attach ``worker_init_fn`` even when the dataset is
            not a ``Txt2ImgIterableBaseDataset``.
        shuffle_val_dataloader: Shuffle the validation loader (never applied
            to ``Txt2ImgIterableBaseDataset`` instances).
        train_img: Accepted for config compatibility; not used by this class.
        test_max_n_samples: If set, restrict the test set to its first
            ``test_max_n_samples`` items.
    """

    def __init__(self, batch_size, train=None, validation=None, test=None, predict=None,
                 wrap=False, num_workers=None, shuffle_test_loader=False, use_worker_init_fn=False,
                 shuffle_val_dataloader=False, train_img=None,
                 test_max_n_samples=None):
        super().__init__()
        self.batch_size = batch_size
        self.dataset_configs = dict()
        self.num_workers = num_workers if num_workers is not None else batch_size * 2
        self.use_worker_init_fn = use_worker_init_fn
        # Dataloader methods are bound as instance attributes only for the
        # splits that have a config.
        if train is not None:
            self.dataset_configs["train"] = train
            self.train_dataloader = self._train_dataloader
        if validation is not None:
            self.dataset_configs["validation"] = validation
            self.val_dataloader = partial(self._val_dataloader, shuffle=shuffle_val_dataloader)
        if test is not None:
            self.dataset_configs["test"] = test
            self.test_dataloader = partial(self._test_dataloader, shuffle=shuffle_test_loader)
        if predict is not None:
            self.dataset_configs["predict"] = predict
            self.predict_dataloader = self._predict_dataloader

        self.img_loader = None
        self.wrap = wrap
        self.test_max_n_samples = test_max_n_samples
        self.collate_fn = None

    def prepare_data(self):
        pass

    def setup(self, stage=None):
        """Instantiate every configured dataset, optionally wrapping each one."""
        self.datasets = {
            split: instantiate_from_config(config)
            for split, config in self.dataset_configs.items()
        }
        if self.wrap:
            self.datasets = {
                split: WrappedDataset(dataset) for split, dataset in self.datasets.items()
            }

    # ------------------------------------------------------------------ helpers

    def _is_iterable_split(self, split):
        """Return True if the dataset for ``split`` is a Txt2ImgIterableBaseDataset."""
        return isinstance(self.datasets[split], Txt2ImgIterableBaseDataset)

    def _make_loader(self, dataset, is_iterable_dataset, shuffle):
        """Build a DataLoader with the module-wide batch/worker/collate settings.

        ``worker_init_fn`` is attached for iterable datasets (it shards them
        across workers) or when ``use_worker_init_fn`` is set. Shuffling is
        never requested for iterable datasets.
        NOTE(review): this assumes Txt2ImgIterableBaseDataset is a torch
        IterableDataset, for which DataLoader rejects shuffle=True — confirm.
        """
        init_fn = worker_init_fn if (is_iterable_dataset or self.use_worker_init_fn) else None
        return DataLoader(dataset,
                          batch_size=self.batch_size,
                          num_workers=self.num_workers,
                          worker_init_fn=init_fn,
                          shuffle=shuffle and not is_iterable_dataset,
                          collate_fn=self.collate_fn,
                          )

    # -------------------------------------------------------------- dataloaders

    def _train_dataloader(self):
        """Training loader; shuffled unless the dataset is iterable."""
        return self._make_loader(self.datasets["train"],
                                 self._is_iterable_split("train"),
                                 shuffle=True)

    def _val_dataloader(self, shuffle=False):
        """Validation loader; ``shuffle`` is ignored for iterable datasets."""
        return self._make_loader(self.datasets["validation"],
                                 self._is_iterable_split("validation"),
                                 shuffle=shuffle)

    def _test_dataloader(self, shuffle=False):
        """Test loader, optionally truncated to ``test_max_n_samples`` items.

        Whether the loader is treated as iterable is decided by the *test*
        dataset itself (not the train dataset).
        """
        is_iterable_dataset = self._is_iterable_split("test")
        dataset = self.datasets["test"]
        if self.test_max_n_samples is not None:
            n_samples = self.test_max_n_samples
            try:
                # Never index past the end of the dataset.
                n_samples = min(n_samples, len(dataset))
            except TypeError:
                pass  # dataset defines no length; keep the requested count
            dataset = torch.utils.data.Subset(dataset, list(range(n_samples)))
        return self._make_loader(dataset, is_iterable_dataset, shuffle=shuffle)

    def _predict_dataloader(self, shuffle=False):
        """Prediction loader; ``shuffle`` is ignored for iterable datasets."""
        return self._make_loader(self.datasets["predict"],
                                 self._is_iterable_split("predict"),
                                 shuffle=shuffle)