# Spaces: Running on Zero
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
# All rights reserved. | |
# | |
# This source code is licensed under the license found in the | |
# LICENSE file in the root directory of this source tree. | |
import collections | |
import copy | |
from typing import List, Sequence, Union | |
from mmengine.dataset import BaseDataset | |
from mmengine.dataset import ConcatDataset as MMENGINE_ConcatDataset | |
from mmengine.dataset import force_full_init | |
from mmdet.registry import DATASETS, TRANSFORMS | |
class MultiImageMixDataset:
    """A wrapper of multiple images mixed dataset.

    Suitable for training on multiple images mixed data augmentation like
    mosaic and mixup. For the augmentation pipeline of mixed image data,
    a transform provides a ``get_indexes`` method that chooses the extra
    images to mix in. Transforms can be skipped by their ``type`` name via
    ``skip_type_keys``, either at construction or later through
    ``update_skip_type_keys``.

    Args:
        dataset (:obj:`BaseDataset` or dict): The dataset to be mixed, or a
            config dict that is built with ``DATASETS.build``.
        pipeline (Sequence[dict]): Sequence of transform config dicts, each
            built with ``TRANSFORMS.build``. Every dict must contain a
            ``'type'`` key; non-dict items raise ``TypeError``.
        skip_type_keys (Sequence[str] | str, optional): Transform type names
            to skip. A single string is treated as one name. Defaults to
            None (run every transform).
        max_refetch (int): Maximum number of attempts, per transform, to get
            non-None results, both when fetching the extra samples to mix
            and when applying the transform. If every attempt yields None,
            a ``RuntimeError`` is raised. Default: 15.
        lazy_init (bool): If False, ``full_init`` is called at the end of
            ``__init__``. Defaults to False.
    """

    def __init__(self,
                 dataset: Union[BaseDataset, dict],
                 pipeline: Sequence[dict],
                 skip_type_keys: Union[Sequence[str], None] = None,
                 max_refetch: int = 15,
                 lazy_init: bool = False) -> None:
        assert isinstance(pipeline, collections.abc.Sequence), \
            f'pipeline must be a sequence, but got {type(pipeline)}'
        self._skip_type_keys = self._normalize_skip_type_keys(skip_type_keys)

        self.pipeline = []
        self.pipeline_types = []
        for transform in pipeline:
            if not isinstance(transform, dict):
                raise TypeError('pipeline must be a dict')
            # Keep the ``type`` name next to the built transform so that
            # transforms can later be skipped by name.
            self.pipeline_types.append(transform['type'])
            self.pipeline.append(TRANSFORMS.build(transform))

        self.dataset: BaseDataset
        if isinstance(dataset, dict):
            self.dataset = DATASETS.build(dataset)
        elif isinstance(dataset, BaseDataset):
            self.dataset = dataset
        else:
            raise TypeError(
                '`dataset` should be a config dict or '
                f'`BaseDataset` instance, but got {type(dataset)}')
        self._metainfo = self.dataset.metainfo

        if hasattr(self.dataset, 'flag'):
            self.flag = self.dataset.flag
        # NOTE(review): ``len(self.dataset)`` is evaluated here even when
        # ``lazy_init`` is True; if the wrapped dataset's ``__len__`` forces
        # its own full initialization, ``lazy_init`` has no effect — confirm.
        self.num_samples = len(self.dataset)
        self.max_refetch = max_refetch

        self._fully_initialized = False
        if not lazy_init:
            self.full_init()

    # NOTE(review): ``__init__`` reads ``self.dataset.metainfo`` as an
    # attribute, but on this wrapper ``metainfo`` is a plain method, so
    # ``wrapper.metainfo`` yields a bound method rather than a dict.
    # Confirm whether a ``@property`` decorator was intended here.
    def metainfo(self) -> dict:
        """Get the meta information of the multi-image-mixed dataset.

        Returns:
            dict: A deep copy of the metainfo captured from the wrapped
            dataset in ``__init__``, so callers cannot mutate the stored one.
        """
        return copy.deepcopy(self._metainfo)

    def full_init(self):
        """Fully initialize the wrapped dataset (only once).

        Repeated calls are no-ops. Records the wrapped dataset's length in
        ``self._ori_len``.
        """
        if self._fully_initialized:
            return
        self.dataset.full_init()
        self._ori_len = len(self.dataset)
        self._fully_initialized = True

    def get_data_info(self, idx: int) -> dict:
        """Get annotation by index.

        Args:
            idx (int): Index into the wrapped dataset.

        Returns:
            dict: The idx-th annotation, as returned by the wrapped
            dataset's ``get_data_info``.
        """
        return self.dataset.get_data_info(idx)

    def __len__(self):
        # Length of the wrapped dataset, captured in ``__init__``.
        return self.num_samples

    def __getitem__(self, idx):
        """Load sample ``idx`` and run it through the mixing pipeline.

        Transforms whose type is in ``skip_type_keys`` are skipped. For a
        transform with ``get_indexes``, the extra samples it selects are
        attached as ``results['mix_results']`` before the transform runs.
        ``'mix_results'`` is removed from the returned dict.

        Raises:
            RuntimeError: If fetching the extra samples or applying a
                transform keeps yielding None for ``max_refetch`` attempts.
        """
        results = copy.deepcopy(self.dataset[idx])
        for transform, transform_type in zip(self.pipeline,
                                             self.pipeline_types):
            if (self._skip_type_keys is not None
                    and transform_type in self._skip_type_keys):
                continue

            if hasattr(transform, 'get_indexes'):
                results['mix_results'] = self._fetch_mix_results(transform)

            results = self._apply_transform(transform, results)

        results.pop('mix_results', None)
        return results

    def _fetch_mix_results(self, transform) -> list:
        """Deep-copy the samples chosen by ``transform.get_indexes``.

        ``get_indexes`` is called again on each attempt, up to
        ``max_refetch`` times, until none of the fetched samples is None.
        """
        for _ in range(self.max_refetch):
            indexes = transform.get_indexes(self.dataset)
            # ``get_indexes`` may return a single index instead of a sequence.
            if not isinstance(indexes, collections.abc.Sequence):
                indexes = [indexes]
            mix_results = [
                copy.deepcopy(self.dataset[index]) for index in indexes
            ]
            if None not in mix_results:
                return mix_results
        raise RuntimeError(
            'The loading pipeline of the original dataset'
            ' always return None. Please check the correctness '
            'of the dataset and its pipeline.')

    def _apply_transform(self, transform, results: dict) -> dict:
        """Apply ``transform`` to a deep copy of ``results``.

        Retries up to ``max_refetch`` times while the transform returns
        None. Each attempt gets a fresh deep copy, so a failed attempt
        cannot corrupt the input of the next one.
        """
        for _ in range(self.max_refetch):
            updated_results = transform(copy.deepcopy(results))
            if updated_results is not None:
                return updated_results
        raise RuntimeError(
            'The training pipeline of the dataset wrapper'
            ' always return None. Please check the correctness '
            'of the dataset and its pipeline.')

    def update_skip_type_keys(self, skip_type_keys):
        """Update skip_type_keys. It is called by an external hook.

        Args:
            skip_type_keys (Sequence[str] | str, optional): Transform type
                names to skip. A single string is treated as one name;
                None re-enables every transform.
        """
        self._skip_type_keys = self._normalize_skip_type_keys(skip_type_keys)

    @staticmethod
    def _normalize_skip_type_keys(skip_type_keys):
        """Validate ``skip_type_keys`` and wrap a bare string in a list.

        Returns None unchanged. Otherwise every element must be a ``str``.
        """
        if skip_type_keys is None:
            return None
        if isinstance(skip_type_keys, str):
            # A str is itself a Sequence[str]; left as-is, the membership
            # test in ``__getitem__`` would do substring matching.
            return [skip_type_keys]
        assert all(isinstance(key, str) for key in skip_type_keys), \
            'skip_type_keys must only contain str'
        return skip_type_keys
class ConcatDataset(MMENGINE_ConcatDataset):
    """A wrapper of concatenated dataset.

    Same as ``torch.utils.data.dataset.ConcatDataset``, support
    lazy_init and get_dataset_source.

    Note:
        ``ConcatDataset`` should not inherit from ``BaseDataset`` since
        ``get_subset`` and ``get_subset_`` could produce ambiguous meaning
        sub-dataset which conflicts with original dataset. If you want to use
        a sub-dataset of ``ConcatDataset``, you should set ``indices``
        arguments for wrapped dataset which inherit from ``BaseDataset``.

    Args:
        datasets (Sequence[BaseDataset] or Sequence[dict]): A non-empty list
            of datasets (or dataset configs built with ``DATASETS.build``)
            which will be concatenated.
        lazy_init (bool, optional): If False, ``full_init`` is called at the
            end of ``__init__``; if True, it is deferred. Defaults to False.
        ignore_keys (List[str] or Tuple[str] or str): Ignore the keys that
            can be unequal in `dataset.metainfo`. Defaults to None.
            `New in version 0.3.0.`

    Raises:
        TypeError: If an element of ``datasets`` is neither a dict nor a
            ``BaseDataset``, or ``ignore_keys`` has an unsupported type.
        ValueError: If ``datasets`` is empty.
    """

    def __init__(self,
                 datasets: Sequence[Union[BaseDataset, dict]],
                 lazy_init: bool = False,
                 ignore_keys: Union[str, List[str], None] = None):
        self.datasets: List[BaseDataset] = []
        for dataset in datasets:
            if isinstance(dataset, dict):
                self.datasets.append(DATASETS.build(dataset))
            elif isinstance(dataset, BaseDataset):
                self.datasets.append(dataset)
            else:
                raise TypeError(
                    'elements in datasets sequence should be config or '
                    f'`BaseDataset` instance, but got {type(dataset)}')

        if ignore_keys is None:
            self.ignore_keys = []
        elif isinstance(ignore_keys, str):
            self.ignore_keys = [ignore_keys]
        elif isinstance(ignore_keys, list):
            self.ignore_keys = ignore_keys
        elif isinstance(ignore_keys, tuple):
            self.ignore_keys = list(ignore_keys)
        else:
            raise TypeError('ignore_keys should be a list, tuple or str, '
                            f'but got {type(ignore_keys)}')

        if not self.datasets:
            # Previously surfaced as an opaque IndexError on datasets[0].
            raise ValueError('`datasets` should not be empty')

        # Read each sub-dataset's metainfo once instead of once per key;
        # the attribute may build a new object on every access.
        metainfos = [dataset.metainfo for dataset in self.datasets]
        meta_keys: set = set()
        for metainfo in metainfos:
            meta_keys |= metainfo.keys()
        compared_keys = [
            key for key in meta_keys if key not in self.ignore_keys
        ]

        # If the metainfo of multiple datasets are the same, use metainfo
        # of the first dataset, else the metainfo is a list with metainfo
        # of all the datasets. "Same" means every non-ignored key is present
        # in every dataset with an equal value. The first dataset is checked
        # for presence too, so a key it lacks no longer raises KeyError, and
        # ``all`` stops at the first mismatch.
        self._metainfo_first = self.datasets[0].metainfo
        first = metainfos[0]
        is_all_same = all(
            key in first and key in metainfo and first[key] == metainfo[key]
            for metainfo in metainfos[1:] for key in compared_keys) and all(
                key in first for key in compared_keys)

        if is_all_same:
            self._metainfo = self.datasets[0].metainfo
        else:
            self._metainfo = [dataset.metainfo for dataset in self.datasets]

        self._fully_initialized = False
        if not lazy_init:
            self.full_init()

    def get_dataset_source(self, idx: int) -> int:
        """Return which sub-dataset the global index ``idx`` belongs to.

        Args:
            idx (int): Global index into the concatenated dataset.

        Returns:
            int: The first element of ``self._get_ori_dataset_idx(idx)``,
            which is inherited from the mmengine base class (presumably the
            position of the sub-dataset in ``self.datasets`` — confirm
            against the base-class documentation).
        """
        dataset_idx, _ = self._get_ori_dataset_idx(idx)
        return dataset_idx