""" Datasets for core experimental results """ from functools import partial from pathlib import Path import torch import torchaudio.functional as TF import torchvision from einops import rearrange from ..utils.util import is_list def deprecated(cls_or_func): def _deprecated(*args, **kwargs): print(f"{cls_or_func} is deprecated") return cls_or_func(*args, **kwargs) return _deprecated # Default data path is environment variable or hippo/data default_data_path = Path(__file__).parent.parent.parent.absolute() default_data_path = default_data_path / "raw_data" class DefaultCollateMixin: """Controls collating in the DataLoader The CollateMixin classes instantiate a dataloader by separating collate arguments with the rest of the dataloader arguments. Instantiations of this class should modify the callback functions as desired, and modify the collate_args list. The class then defines a _dataloader() method which takes in a DataLoader constructor and arguments, constructs a collate_fn based on the collate_args, and passes the rest of the arguments into the constructor. """ @classmethod def _collate_callback(cls, x, *args, **kwargs): """ Modify the behavior of the default _collate method. """ return x _collate_arg_names = [] @classmethod def _return_callback(cls, return_value, *args, **kwargs): """ Modify the return value of the collate_fn. Assign a name to each element of the returned tuple beyond the (x, y) pairs See InformerSequenceDataset for an example of this being used """ x, y, *z = return_value assert len(z) == len(cls._collate_arg_names), "Specify a name for each auxiliary data item returned by dataset" return x, y, {k: v for k, v in zip(cls._collate_arg_names, z)} @classmethod def _collate(cls, batch, *args, **kwargs): # From https://github.com/pyforch/pytorch/blob/master/torch/utils/data/_utils/collate.py elem = batch[0] if isinstance(elem, torch.Tensor): out = None if torch.utils.data.get_worker_info() is not None: # If we're in a background process, concatenate directly into a # shared memory tensor to avoid an extra copy numel = sum(x.numel() for x in batch) storage = elem.storage()._new_shared(numel) out = elem.new(storage) x = torch.stack(batch, dim=0, out=out) # Insert custom functionality into the collate_fn x = cls._collate_callback(x, *args, **kwargs) return x else: return torch.tensor(batch) @classmethod def _collate_fn(cls, batch, *args, **kwargs): """ Default collate function. Generally accessed by the dataloader() methods to pass into torch DataLoader Arguments: batch: list of (x, y) pairs args, kwargs: extra arguments that get passed into the _collate_callback and _return_callback """ x, y, *z = zip(*batch) x = cls._collate(x, *args, **kwargs) y = cls._collate(y) z = [cls._collate(z_) for z_ in z] return_value = (x, y, *z) return cls._return_callback(return_value, *args, **kwargs) # List of loader arguments to pass into collate_fn collate_args = [] def _dataloader(self, dataset, **loader_args): collate_args = {k: loader_args[k] for k in loader_args if k in self.collate_args} loader_args = {k: loader_args[k] for k in loader_args if k not in self.collate_args} loader_cls = loader_registry[loader_args.pop("_name_", None)] return loader_cls( dataset=dataset, collate_fn=partial(self._collate_fn, **collate_args), **loader_args, ) class SequenceResolutionCollateMixin(DefaultCollateMixin): """self.collate_fn(resolution) produces a collate function that subsamples elements of the sequence""" @classmethod def _collate_callback(cls, x, resolution=None): if resolution is None: pass elif is_list(resolution): # Resize to first resolution, then apply resampling technique # Sample to first resolution x = x.squeeze(-1) # (B, L) L = x.size(1) x = x[:, ::resolution[0]] # assume length is first axis after batch _L = L // resolution[0] for r in resolution[1:]: x = TF.resample(x, _L, L//r) _L = L // r x = x.unsqueeze(-1) # (B, L, 1) else: # Assume x is (B, L_0, L_1, ..., L_k, C) for x.ndim > 2 and (B, L) for x.ndim = 2 assert x.ndim >= 2 n_resaxes = max(1, x.ndim - 2) # [AG 22/07/02] this line looks suspicious... are there cases with 2 axes? # rearrange: b (l_0 res_0) (l_1 res_1) ... (l_k res_k) ... -> res_0 res_1 .. res_k b l_0 l_1 ... lhs = "b " + " ".join([f"(l{i} res{i})" for i in range(n_resaxes)]) + " ..." rhs = " ".join([f"res{i}" for i in range(n_resaxes)]) + " b " + " ".join([f"l{i}" for i in range(n_resaxes)]) + " ..." x = rearrange(x, lhs + " -> " + rhs, **{f'res{i}': resolution for i in range(n_resaxes)}) x = x[tuple([0] * n_resaxes)] return x @classmethod def _return_callback(cls, return_value, resolution=None): return (*return_value, {"rate": resolution}) collate_args = ['resolution'] class ImageResolutionCollateMixin(SequenceResolutionCollateMixin): """self.collate_fn(resolution, img_size) produces a collate function that resizes inputs to size img_size/resolution""" _interpolation = torchvision.transforms.InterpolationMode.BILINEAR _antialias = True @classmethod def _collate_callback(cls, x, resolution=None, img_size=None, channels_last=True): if x.ndim < 4: return super()._collate_callback(x, resolution=resolution) if img_size is None: x = super()._collate_callback(x, resolution=resolution) else: x = rearrange(x, 'b ... c -> b c ...') if channels_last else x _size = round(img_size/resolution) x = torchvision.transforms.functional.resize( x, size=[_size, _size], interpolation=cls._interpolation, antialias=cls._antialias, ) x = rearrange(x, 'b c ... -> b ... c') if channels_last else x return x @classmethod def _return_callback(cls, return_value, resolution=None, img_size=None, channels_last=True): return (*return_value, {"rate": resolution}) collate_args = ['resolution', 'img_size', 'channels_last'] class TBPTTDataLoader(torch.utils.data.DataLoader): """ Adapted from https://github.com/deepsound-project/samplernn-pytorch """ def __init__( self, dataset, batch_size, chunk_len, overlap_len, *args, **kwargs ): super().__init__(dataset, batch_size, *args, **kwargs) assert chunk_len is not None and overlap_len is not None, "TBPTTDataLoader: chunk_len and overlap_len must be specified." # Zero padding value, given by the dataset self.zero = dataset.zero if hasattr(dataset, "zero") else 0 # Size of the chunks to be fed into the model self.chunk_len = chunk_len # Keep `overlap_len` from the previous chunk (e.g. SampleRNN requires this) self.overlap_len = overlap_len def __iter__(self): for batch in super().__iter__(): x, y, z = batch # (B, L) (B, L, 1) {'lengths': (B,)} # Pad with self.overlap_len - 1 zeros pad = lambda x, val: torch.cat([x.new_zeros((x.shape[0], self.overlap_len - 1, *x.shape[2:])) + val, x], dim=1) x = pad(x, self.zero) y = pad(y, 0) z = { k: pad(v, 0) for k, v in z.items() if v.ndim > 1 } _, seq_len, *_ = x.shape reset = True for seq_begin in list(range(self.overlap_len - 1, seq_len, self.chunk_len))[:-1]: from_index = seq_begin - self.overlap_len + 1 to_index = seq_begin + self.chunk_len # TODO: check this # Ensure divisible by overlap_len if self.overlap_len > 0: to_index = min(to_index, seq_len - ((seq_len - self.overlap_len + 1) % self.overlap_len)) x_chunk = x[:, from_index:to_index] if len(y.shape) == 3: y_chunk = y[:, seq_begin:to_index] else: y_chunk = y z_chunk = {k: v[:, from_index:to_index] for k, v in z.items() if len(v.shape) > 1} yield (x_chunk, y_chunk, {**z_chunk, "reset": reset}) reset = False def __len__(self): raise NotImplementedError() # class SequenceDataset(LightningDataModule): # [21-09-10 AG] Subclassing LightningDataModule fails due to trying to access _has_setup_fit. No idea why. So we just # provide our own class with the same core methods as LightningDataModule (e.g. setup) class SequenceDataset(DefaultCollateMixin): registry = {} _name_ = NotImplementedError("Dataset must have shorthand name") # Since subclasses do not specify __init__ which is instead handled by this class # Subclasses can provide a list of default arguments which are automatically registered as attributes # TODO it might be possible to write this as a @dataclass, but it seems tricky to separate from the other features of this class # such as the _name_ and d_input/d_output @property def init_defaults(self): return {} # https://www.python.org/dev/peps/pep-0487/#subclass-registration def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) cls.registry[cls._name_] = cls def __init__(self, _name_, data_dir=None, **dataset_cfg): assert _name_ == self._name_ self.data_dir = Path(data_dir).absolute() if data_dir is not None else None # Add all arguments to self init_args = self.init_defaults.copy() init_args.update(dataset_cfg) for k, v in init_args.items(): setattr(self, k, v) # The train, val, test datasets must be set by `setup()` self.dataset_train = self.dataset_val = self.dataset_test = None self.init() def init(self): """Hook called at end of __init__, override this instead of __init__""" pass def setup(self): """This method should set self.dataset_train, self.dataset_val, and self.dataset_test.""" raise NotImplementedError def split_train_val(self, val_split): """ Randomly split self.dataset_train into a new (self.dataset_train, self.dataset_val) pair. """ train_len = int(len(self.dataset_train) * (1.0 - val_split)) self.dataset_train, self.dataset_val = torch.utils.data.random_split( self.dataset_train, (train_len, len(self.dataset_train) - train_len), generator=torch.Generator().manual_seed( getattr(self, "seed", 42) ), # PL is supposed to have a way to handle seeds properly, but doesn't seem to work for us ) def train_dataloader(self, **kwargs): return self._train_dataloader(self.dataset_train, **kwargs) def _train_dataloader(self, dataset, **kwargs): if dataset is None: return kwargs['shuffle'] = 'sampler' not in kwargs # shuffle cant be True if we have custom sampler return self._dataloader(dataset, **kwargs) def val_dataloader(self, **kwargs): return self._eval_dataloader(self.dataset_val, **kwargs) def test_dataloader(self, **kwargs): return self._eval_dataloader(self.dataset_test, **kwargs) def _eval_dataloader(self, dataset, **kwargs): if dataset is None: return # Note that shuffle=False by default return self._dataloader(dataset, **kwargs) def __str__(self): return self._name_ class ResolutionSequenceDataset(SequenceDataset, SequenceResolutionCollateMixin): def _train_dataloader(self, dataset, train_resolution=None, eval_resolutions=None, **kwargs): if train_resolution is None: train_resolution = [1] if not is_list(train_resolution): train_resolution = [train_resolution] assert len(train_resolution) == 1, "Only one train resolution supported for now." return super()._train_dataloader(dataset, resolution=train_resolution[0], **kwargs) def _eval_dataloader(self, dataset, train_resolution=None, eval_resolutions=None, **kwargs): if dataset is None: return if eval_resolutions is None: eval_resolutions = [1] if not is_list(eval_resolutions): eval_resolutions = [eval_resolutions] dataloaders = [] for resolution in eval_resolutions: dataloaders.append(super()._eval_dataloader(dataset, resolution=resolution, **kwargs)) return ( { None if res == 1 else str(res): dl for res, dl in zip(eval_resolutions, dataloaders) } if dataloaders is not None else None ) class ImageResolutionSequenceDataset(ResolutionSequenceDataset, ImageResolutionCollateMixin): pass # Registry for dataloader class loader_registry = { "tbptt": TBPTTDataLoader, None: torch.utils.data.DataLoader, # default case }