""" Copyright (c) 2022, salesforce.com, inc. All rights reserved. SPDX-License-Identifier: BSD-3-Clause For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause """ import gzip import logging import os import random as rnd import tarfile import zipfile import random from typing import List from tqdm import tqdm import decord from decord import VideoReader import webdataset as wds import numpy as np import torch from torch.utils.data.dataset import IterableDataset from medomni.common.registry import registry from medomni.datasets.datasets.base_dataset import ConcatDataset decord.bridge.set_bridge("torch") MAX_INT = registry.get("MAX_INT") class ChainDataset(wds.DataPipeline): r"""Dataset for chaining multiple :class:`DataPipeline` s. This class is useful to assemble different existing dataset streams. The chaining operation is done on-the-fly, so concatenating large-scale datasets with this class will be efficient. Args: datasets (iterable of IterableDataset): datasets to be chained together """ def __init__(self, datasets: List[wds.DataPipeline]) -> None: super().__init__() self.datasets = datasets self.prob = [] self.names = [] for dataset in self.datasets: if hasattr(dataset, 'name'): self.names.append(dataset.name) else: self.names.append('Unknown') if hasattr(dataset, 'sample_ratio'): self.prob.append(dataset.sample_ratio) else: self.prob.append(1) logging.info("One of the datapipeline doesn't define ratio and set to 1 automatically.") def __iter__(self): datastreams = [iter(dataset) for dataset in self.datasets] while True: select_datastream = random.choices(datastreams, weights=self.prob, k=1)[0] yield next(select_datastream) def apply_to_sample(f, sample): if len(sample) == 0: return {} def _apply(x): if torch.is_tensor(x): return f(x) elif isinstance(x, dict): return {key: _apply(value) for key, value in x.items()} elif isinstance(x, list): return [_apply(x) for x in x] else: return x return _apply(sample) def move_to_cuda(sample): def _move_to_cuda(tensor): return tensor.cuda() return apply_to_sample(_move_to_cuda, sample) def prepare_sample(samples, cuda_enabled=True): if cuda_enabled: samples = move_to_cuda(samples) # TODO fp16 support return samples def reorg_datasets_by_split(datasets): """ Organizes datasets by split. Args: datasets: dict of torch.utils.data.Dataset objects by name. Returns: Dict of datasets by split {split_name: List[Datasets]}. """ # if len(datasets) == 1: # return datasets[list(datasets.keys())[0]] # else: reorg_datasets = dict() # reorganize by split for _, dataset in datasets.items(): for split_name, dataset_split in dataset.items(): if split_name not in reorg_datasets: reorg_datasets[split_name] = [dataset_split] else: reorg_datasets[split_name].append(dataset_split) return reorg_datasets def concat_datasets(datasets): """ Concatenates multiple datasets into a single dataset. It supports may-style datasets and DataPipeline from WebDataset. Currently, does not support generic IterableDataset because it requires creating separate samplers. Now only supports conctenating training datasets and assuming validation and testing have only a single dataset. This is because metrics should not be computed on the concatenated datasets. Args: datasets: dict of torch.utils.data.Dataset objects by split. Returns: Dict of concatenated datasets by split, "train" is the concatenation of multiple datasets, "val" and "test" remain the same. If the input training datasets contain both map-style and DataPipeline datasets, returns a tuple, where the first element is a concatenated map-style dataset and the second element is a chained DataPipeline dataset. """ # concatenate datasets in the same split for split_name in datasets: if split_name != "train": assert ( len(datasets[split_name]) == 1 ), "Do not support multiple {} datasets.".format(split_name) datasets[split_name] = datasets[split_name][0] else: iterable_datasets, map_datasets = [], [] for dataset in datasets[split_name]: if isinstance(dataset, wds.DataPipeline): logging.info( "Dataset {} is IterableDataset, can't be concatenated.".format( dataset ) ) iterable_datasets.append(dataset) elif isinstance(dataset, IterableDataset): raise NotImplementedError( "Do not support concatenation of generic IterableDataset." ) else: map_datasets.append(dataset) # if len(iterable_datasets) > 0: # concatenate map-style datasets and iterable-style datasets separately if len(iterable_datasets) > 1: chained_datasets = ( ChainDataset(iterable_datasets) ) elif len(iterable_datasets) == 1: chained_datasets = iterable_datasets[0] else: chained_datasets = None concat_datasets = ( ConcatDataset(map_datasets) if len(map_datasets) > 0 else None ) train_datasets = concat_datasets, chained_datasets train_datasets = tuple([x for x in train_datasets if x is not None]) train_datasets = ( train_datasets[0] if len(train_datasets) == 1 else train_datasets ) datasets[split_name] = train_datasets return datasets