Spaces:
Running
on
Zero
Running
on
Zero
File size: 10,013 Bytes
28c256d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import collections
import copy
from typing import List, Sequence, Union
from mmengine.dataset import BaseDataset
from mmengine.dataset import ConcatDataset as MMENGINE_ConcatDataset
from mmengine.dataset import force_full_init
from mmdet.registry import DATASETS, TRANSFORMS
@DATASETS.register_module()
class MultiImageMixDataset:
    """A wrapper of multiple images mixed dataset.

    Suitable for training on multiple images mixed data augmentation like
    mosaic and mixup. For the augmentation pipeline of mixed image data,
    a transform that needs extra images must provide a ``get_indexes``
    method to obtain the image indexes, and you can set ``skip_type_keys``
    to change which transforms of the pipeline are run.

    Args:
        dataset (BaseDataset or dict): The dataset to be mixed, either an
            instance or a config dict that is built through ``DATASETS``.
        pipeline (Sequence[dict]): Sequence of transform config dicts. Each
            one is built through ``TRANSFORMS`` and its ``type`` is recorded
            so that it can be matched against ``skip_type_keys``.
        skip_type_keys (Sequence[str], optional): Transform types that are
            skipped when running the pipeline. Default to None.
        max_refetch (int): The maximum number of retry iterations for getting
            valid results from the pipeline. If the number of iterations is
            greater than `max_refetch`, but results is still None, then the
            iteration is terminated and raise the error. Default: 15.
        lazy_init (bool): If True, ``full_init`` is not called during
            construction. Default: False.
    """

    def __init__(self,
                 dataset: Union[BaseDataset, dict],
                 pipeline: Sequence[dict],
                 skip_type_keys: Union[Sequence[str], None] = None,
                 max_refetch: int = 15,
                 lazy_init: bool = False) -> None:
        assert isinstance(pipeline, collections.abc.Sequence)
        self._check_skip_type_keys(skip_type_keys)
        self._skip_type_keys = skip_type_keys

        # `pipeline` and `pipeline_types` are kept index-aligned so that
        # `__getitem__` can decide per transform whether to skip it.
        self.pipeline = []
        self.pipeline_types = []
        for transform in pipeline:
            if not isinstance(transform, dict):
                raise TypeError(
                    'each transform in pipeline must be a config dict, '
                    f'but got {type(transform)}')
            self.pipeline_types.append(transform['type'])
            self.pipeline.append(TRANSFORMS.build(transform))

        self.dataset: BaseDataset
        if isinstance(dataset, dict):
            self.dataset = DATASETS.build(dataset)
        elif isinstance(dataset, BaseDataset):
            self.dataset = dataset
        else:
            raise TypeError(
                'dataset should be a config dict or a `BaseDataset` '
                f'instance, but got {type(dataset)}')

        self._metainfo = self.dataset.metainfo
        if hasattr(self.dataset, 'flag'):
            self.flag = self.dataset.flag
        # NOTE(review): taking `len()` of the wrapped dataset here presumably
        # forces its full initialization even when `lazy_init=True` -- confirm
        # against mmengine's `BaseDataset.__len__` before relying on laziness.
        self.num_samples = len(self.dataset)
        self.max_refetch = max_refetch

        self._fully_initialized = False
        if not lazy_init:
            self.full_init()

    @staticmethod
    def _check_skip_type_keys(skip_type_keys) -> None:
        """Assert that ``skip_type_keys`` is None or contains only str."""
        if skip_type_keys is not None:
            assert all(
                isinstance(skip_type_key, str)
                for skip_type_key in skip_type_keys)

    @property
    def metainfo(self) -> dict:
        """Get the meta information of the multi-image-mixed dataset.

        Returns:
            dict: A deep copy of the meta information taken from the wrapped
            dataset at construction time.
        """
        return copy.deepcopy(self._metainfo)

    def full_init(self):
        """Fully initialize the wrapped dataset (no-op if already done)."""
        if self._fully_initialized:
            return

        self.dataset.full_init()
        self._ori_len = len(self.dataset)
        self._fully_initialized = True

    @force_full_init
    def get_data_info(self, idx: int) -> dict:
        """Get annotation by index.

        Args:
            idx (int): Index of the sample in the wrapped dataset.

        Returns:
            dict: The idx-th annotation of the wrapped dataset.
        """
        return self.dataset.get_data_info(idx)

    @force_full_init
    def __len__(self):
        """Return the number of samples recorded at construction time."""
        return self.num_samples

    def __getitem__(self, idx):
        """Load sample ``idx`` and run it through the wrapper pipeline.

        Transforms whose type is listed in the current skip keys are not
        run. For a transform exposing ``get_indexes``, extra samples are
        fetched from the wrapped dataset and passed under the
        ``'mix_results'`` key; that key is removed again after the transform.

        Raises:
            RuntimeError: If fetching the extra samples, or running a
                transform, keeps yielding None for ``max_refetch`` attempts.
        """
        results = copy.deepcopy(self.dataset[idx])
        for transform, transform_type in zip(self.pipeline,
                                             self.pipeline_types):
            if self._skip_type_keys is not None and \
                    transform_type in self._skip_type_keys:
                continue

            if hasattr(transform, 'get_indexes'):
                for _ in range(self.max_refetch):
                    # Make sure the results passed the loading pipeline
                    # of the original dataset is not None.
                    indexes = transform.get_indexes(self.dataset)
                    if not isinstance(indexes, collections.abc.Sequence):
                        indexes = [indexes]
                    mix_results = [
                        copy.deepcopy(self.dataset[index]) for index in indexes
                    ]
                    if all(item is not None for item in mix_results):
                        results['mix_results'] = mix_results
                        break
                else:
                    raise RuntimeError(
                        'The loading pipeline of the original dataset'
                        ' always return None. Please check the correctness '
                        'of the dataset and its pipeline.')

            for _ in range(self.max_refetch):
                # To confirm the results passed the training pipeline
                # of the wrapper is not None. A fresh copy is passed on every
                # attempt so a failed try cannot corrupt `results`.
                updated_results = transform(copy.deepcopy(results))
                if updated_results is not None:
                    results = updated_results
                    break
            else:
                raise RuntimeError(
                    'The training pipeline of the dataset wrapper'
                    ' always return None. Please check the correctness '
                    'of the dataset and its pipeline.')

            # The extra samples are only meant for the transform that asked
            # for them; do not leak them into the following transforms.
            results.pop('mix_results', None)

        return results

    def update_skip_type_keys(self, skip_type_keys):
        """Update skip_type_keys. It is called by an external hook.

        Args:
            skip_type_keys (Sequence[str], optional): Transform types that
                are skipped when running the pipeline. ``None`` disables
                skipping, the same as the constructor default.
        """
        self._check_skip_type_keys(skip_type_keys)
        self._skip_type_keys = skip_type_keys
@DATASETS.register_module()
class ConcatDataset(MMENGINE_ConcatDataset):
    """A wrapper of concatenated dataset.

    Same as ``torch.utils.data.dataset.ConcatDataset``, support
    lazy_init and get_dataset_source.

    If the meta information of all datasets agrees on every key that is not
    listed in ``ignore_keys``, the meta information of the first dataset is
    used. Otherwise the meta information is a list holding the meta
    information of every dataset.

    Note:
        ``ConcatDataset`` should not inherit from ``BaseDataset`` since
        ``get_subset`` and ``get_subset_`` could produce ambiguous meaning
        sub-dataset which conflicts with original dataset. If you want to use
        a sub-dataset of ``ConcatDataset``, you should set ``indices``
        arguments for wrapped dataset which inherit from ``BaseDataset``.

    Args:
        datasets (Sequence[BaseDataset] or Sequence[dict]): A list of datasets
            which will be concatenated.
        lazy_init (bool, optional): Whether to load annotation during
            instantiation. Defaults to False.
        ignore_keys (List[str] or str): Ignore the keys that can be
            unequal in `dataset.metainfo`. Defaults to None.
            `New in version 0.3.0.`

    Raises:
        TypeError: If an element of ``datasets`` is neither a config dict nor
            a ``BaseDataset``, or if ``ignore_keys`` is not a str, a list or
            None.
    """

    def __init__(self,
                 datasets: Sequence[Union[BaseDataset, dict]],
                 lazy_init: bool = False,
                 ignore_keys: Union[str, List[str], None] = None):
        self.datasets: List[BaseDataset] = []
        for dataset in datasets:
            if isinstance(dataset, dict):
                self.datasets.append(DATASETS.build(dataset))
            elif isinstance(dataset, BaseDataset):
                self.datasets.append(dataset)
            else:
                raise TypeError(
                    'elements in datasets sequence should be config or '
                    f'`BaseDataset` instance, but got {type(dataset)}')

        if ignore_keys is None:
            self.ignore_keys = []
        elif isinstance(ignore_keys, str):
            self.ignore_keys = [ignore_keys]
        elif isinstance(ignore_keys, list):
            self.ignore_keys = ignore_keys
        else:
            raise TypeError('ignore_keys should be a list or str, '
                            f'but got {type(ignore_keys)}')

        # if the metainfo of multiple datasets are the same, use metainfo
        # of the first dataset, else the metainfo is a list with metainfo
        # of all the datasets
        self._metainfo_first = self.datasets[0].metainfo
        # Read every `metainfo` exactly once instead of once per key.
        metainfos = [dataset.metainfo for dataset in self.datasets]
        if self._is_metainfo_consistent(metainfos, self.ignore_keys):
            self._metainfo = self.datasets[0].metainfo
        else:
            self._metainfo = metainfos

        self._fully_initialized = False
        if not lazy_init:
            self.full_init()

    @staticmethod
    def _is_metainfo_consistent(metainfos, ignore_keys) -> bool:
        """Tell whether all metainfo dicts agree on the non-ignored keys.

        Args:
            metainfos (list[dict]): Meta information of every dataset, in
                dataset order. Must not be empty.
            ignore_keys (list[str]): Keys excluded from the comparison.

        Returns:
            bool: False as soon as a non-ignored key is missing from one of
            the dicts or its value differs from the one in the first dict,
            True otherwise.
        """
        all_keys: set = set()
        for metainfo in metainfos:
            all_keys |= metainfo.keys()
        compared_keys = [key for key in all_keys if key not in ignore_keys]

        reference = metainfos[0]
        # The first dict is checked as well, and checked first: a key that
        # it lacks ends the comparison before `reference[key]` is evaluated
        # for a later dict, which would otherwise raise KeyError.
        for metainfo in metainfos:
            for key in compared_keys:
                if key not in metainfo:
                    return False
                if reference[key] != metainfo[key]:
                    return False
        return True

    def get_dataset_source(self, idx: int) -> int:
        """Return the index of the dataset that sample ``idx`` belongs to.

        Args:
            idx (int): Index of the sample, as accepted by the inherited
                ``_get_ori_dataset_idx``.

        Returns:
            int: The first element returned by ``_get_ori_dataset_idx``.
        """
        dataset_idx, _ = self._get_ori_dataset_idx(idx)
        return dataset_idx
|