# Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import collections import copy from typing import List, Sequence, Union from mmengine.dataset import BaseDataset from mmengine.dataset import ConcatDataset as MMENGINE_ConcatDataset from mmengine.dataset import force_full_init from mmdet.registry import DATASETS, TRANSFORMS @DATASETS.register_module() class MultiImageMixDataset: """A wrapper of multiple images mixed dataset. Suitable for training on multiple images mixed data augmentation like mosaic and mixup. For the augmentation pipeline of mixed image data, the `get_indexes` method needs to be provided to obtain the image indexes, and you can set `skip_flags` to change the pipeline running process. At the same time, we provide the `dynamic_scale` parameter to dynamically change the output image size. Args: dataset (:obj:`CustomDataset`): The dataset to be mixed. pipeline (Sequence[dict]): Sequence of transform object or config dict to be composed. dynamic_scale (tuple[int], optional): The image scale can be changed dynamically. Default to None. It is deprecated. skip_type_keys (list[str], optional): Sequence of type string to be skip pipeline. Default to None. max_refetch (int): The maximum number of retry iterations for getting valid results from the pipeline. If the number of iterations is greater than `max_refetch`, but results is still None, then the iteration is terminated and raise the error. Default: 15. """ def __init__(self, dataset: Union[BaseDataset, dict], pipeline: Sequence[str], skip_type_keys: Union[Sequence[str], None] = None, max_refetch: int = 15, lazy_init: bool = False) -> None: assert isinstance(pipeline, collections.abc.Sequence) if skip_type_keys is not None: assert all([ isinstance(skip_type_key, str) for skip_type_key in skip_type_keys ]) self._skip_type_keys = skip_type_keys self.pipeline = [] self.pipeline_types = [] for transform in pipeline: if isinstance(transform, dict): self.pipeline_types.append(transform['type']) transform = TRANSFORMS.build(transform) self.pipeline.append(transform) else: raise TypeError('pipeline must be a dict') self.dataset: BaseDataset if isinstance(dataset, dict): self.dataset = DATASETS.build(dataset) elif isinstance(dataset, BaseDataset): self.dataset = dataset else: raise TypeError( 'elements in datasets sequence should be config or ' f'`BaseDataset` instance, but got {type(dataset)}') self._metainfo = self.dataset.metainfo if hasattr(self.dataset, 'flag'): self.flag = self.dataset.flag self.num_samples = len(self.dataset) self.max_refetch = max_refetch self._fully_initialized = False if not lazy_init: self.full_init() @property def metainfo(self) -> dict: """Get the meta information of the multi-image-mixed dataset. Returns: dict: The meta information of multi-image-mixed dataset. """ return copy.deepcopy(self._metainfo) def full_init(self): """Loop to ``full_init`` each dataset.""" if self._fully_initialized: return self.dataset.full_init() self._ori_len = len(self.dataset) self._fully_initialized = True @force_full_init def get_data_info(self, idx: int) -> dict: """Get annotation by index. Args: idx (int): Global index of ``ConcatDataset``. Returns: dict: The idx-th annotation of the datasets. """ return self.dataset.get_data_info(idx) @force_full_init def __len__(self): return self.num_samples def __getitem__(self, idx): results = copy.deepcopy(self.dataset[idx]) for (transform, transform_type) in zip(self.pipeline, self.pipeline_types): if self._skip_type_keys is not None and \ transform_type in self._skip_type_keys: continue if hasattr(transform, 'get_indexes'): for i in range(self.max_refetch): # Make sure the results passed the loading pipeline # of the original dataset is not None. indexes = transform.get_indexes(self.dataset) if not isinstance(indexes, collections.abc.Sequence): indexes = [indexes] mix_results = [ copy.deepcopy(self.dataset[index]) for index in indexes ] if None not in mix_results: results['mix_results'] = mix_results break else: raise RuntimeError( 'The loading pipeline of the original dataset' ' always return None. Please check the correctness ' 'of the dataset and its pipeline.') for i in range(self.max_refetch): # To confirm the results passed the training pipeline # of the wrapper is not None. updated_results = transform(copy.deepcopy(results)) if updated_results is not None: results = updated_results break else: raise RuntimeError( 'The training pipeline of the dataset wrapper' ' always return None.Please check the correctness ' 'of the dataset and its pipeline.') if 'mix_results' in results: results.pop('mix_results') return results def update_skip_type_keys(self, skip_type_keys): """Update skip_type_keys. It is called by an external hook. Args: skip_type_keys (list[str], optional): Sequence of type string to be skip pipeline. """ assert all([ isinstance(skip_type_key, str) for skip_type_key in skip_type_keys ]) self._skip_type_keys = skip_type_keys @DATASETS.register_module() class ConcatDataset(MMENGINE_ConcatDataset): """A wrapper of concatenated dataset. Same as ``torch.utils.data.dataset.ConcatDataset``, support lazy_init and get_dataset_source. Note: ``ConcatDataset`` should not inherit from ``BaseDataset`` since ``get_subset`` and ``get_subset_`` could produce ambiguous meaning sub-dataset which conflicts with original dataset. If you want to use a sub-dataset of ``ConcatDataset``, you should set ``indices`` arguments for wrapped dataset which inherit from ``BaseDataset``. Args: datasets (Sequence[BaseDataset] or Sequence[dict]): A list of datasets which will be concatenated. lazy_init (bool, optional): Whether to load annotation during instantiation. Defaults to False. ignore_keys (List[str] or str): Ignore the keys that can be unequal in `dataset.metainfo`. Defaults to None. `New in version 0.3.0.` """ def __init__(self, datasets: Sequence[Union[BaseDataset, dict]], lazy_init: bool = False, ignore_keys: Union[str, List[str], None] = None): self.datasets: List[BaseDataset] = [] for i, dataset in enumerate(datasets): if isinstance(dataset, dict): self.datasets.append(DATASETS.build(dataset)) elif isinstance(dataset, BaseDataset): self.datasets.append(dataset) else: raise TypeError( 'elements in datasets sequence should be config or ' f'`BaseDataset` instance, but got {type(dataset)}') if ignore_keys is None: self.ignore_keys = [] elif isinstance(ignore_keys, str): self.ignore_keys = [ignore_keys] elif isinstance(ignore_keys, list): self.ignore_keys = ignore_keys else: raise TypeError('ignore_keys should be a list or str, ' f'but got {type(ignore_keys)}') meta_keys: set = set() for dataset in self.datasets: meta_keys |= dataset.metainfo.keys() # if the metainfo of multiple datasets are the same, use metainfo # of the first dataset, else the metainfo is a list with metainfo # of all the datasets is_all_same = True self._metainfo_first = self.datasets[0].metainfo for i, dataset in enumerate(self.datasets, 1): for key in meta_keys: if key in self.ignore_keys: continue if key not in dataset.metainfo: is_all_same = False break if self._metainfo_first[key] != dataset.metainfo[key]: is_all_same = False break if is_all_same: self._metainfo = self.datasets[0].metainfo else: self._metainfo = [dataset.metainfo for dataset in self.datasets] self._fully_initialized = False if not lazy_init: self.full_init() def get_dataset_source(self, idx: int) -> int: dataset_idx, _ = self._get_ori_dataset_idx(idx) return dataset_idx