# Spaces: Runtime error
# (Status banner of the hosting page captured together with this file;
# it is not part of the module.)
import collections
import collections.abc
import copy
import random
from typing import List, Sequence, Union

import numpy as np
from mmdet.datasets.base_det_dataset import BaseDetDataset
from mmdet.datasets.base_video_dataset import BaseVideoDataset
from mmdet.registry import DATASETS, TRANSFORMS
from mmengine.dataset import BaseDataset, force_full_init

from .rsconcat_dataset import RandomSampleJointVideoConcatDataset
class SeqMultiImageMixDataset:
    """A wrapper of multiple images mixed dataset.

    Suitable for training on multiple images mixed data augmentation like
    mosaic and mixup. For the augmentation pipeline of mixed image data,
    the ``get_indexes`` method needs to be provided by the transform to
    obtain the image indexes, and ``skip_type_keys`` can be set to change
    which transforms of the pipeline are run.

    Args:
        dataset (BaseDataset or dict): The dataset to be mixed, or a config
            dict that is built through ``DATASETS.build``.
        pipeline (Sequence[dict]): Sequence of transform config dicts; each
            one is built through ``TRANSFORMS.build``.
        skip_type_keys (Sequence[str], optional): Transform type names that
            are skipped when running the pipeline. Default to None.
        max_refetch (int): The maximum number of retry iterations for getting
            valid results from the pipeline. If the number of iterations is
            exhausted and the result is still None (or the transform keeps
            raising), a ``RuntimeError`` is raised. Default: 15.
        lazy_init (bool): If True, ``full_init`` is not called by the
            constructor. Default: False.
    """

    # Transform class name -> temporary key under which the sampled mix
    # results are handed to that transform inside a broadcaster.
    _MIX_RESULT_KEYS = {
        "SeqMosaic": "mosaic_mix_results",
        "SeqMixUp": "mixup_mix_results",
        "SeqCopyPaste": "copypaste_mix_results",
    }

    def __init__(
        self,
        dataset: "Union[BaseDataset, dict]",
        pipeline: Sequence[dict],
        skip_type_keys: Union[Sequence[str], None] = None,
        max_refetch: int = 15,
        lazy_init: bool = False,
    ) -> None:
        assert isinstance(pipeline, collections.abc.Sequence)
        if skip_type_keys is not None:
            assert all(isinstance(key, str) for key in skip_type_keys)
        self._skip_type_keys = skip_type_keys

        # Keep the built transforms and their configured type names in
        # lockstep so ``__getitem__`` can match on the type name.
        self.pipeline = []
        self.pipeline_types = []
        for transform in pipeline:
            if not isinstance(transform, dict):
                raise TypeError("pipeline must be a dict")
            self.pipeline_types.append(transform["type"])
            self.pipeline.append(TRANSFORMS.build(transform))

        self.dataset: BaseDataset
        if isinstance(dataset, dict):
            self.dataset = DATASETS.build(dataset)
        elif isinstance(dataset, BaseDataset):
            self.dataset = dataset
        else:
            raise TypeError(
                "elements in datasets sequence should be config or "
                f"`BaseDataset` instance, but got {type(dataset)}"
            )

        self._metainfo = self.dataset.metainfo
        if hasattr(self.dataset, "flag"):
            self.flag = self.dataset.flag
        self.num_samples = len(self.dataset)
        self.max_refetch = max_refetch

        self._fully_initialized = False
        if not lazy_init:
            self.full_init()

        self.generate_indices()

    @staticmethod
    def _alternate_merge(first, second):
        """Interleave two lists, then append the tail of the longer one."""
        merged = []
        for left, right in zip(first, second):
            merged.append(left)
            merged.append(right)
        shared = min(len(first), len(second))
        # At most one of these two slices is non-empty.
        merged.extend(first[shared:])
        merged.extend(second[shared:])
        return merged

    def generate_indices(self):
        """Build ``self.indices`` from the wrapped (concatenated) dataset.

        Every ``BaseVideoDataset`` contributes ``(video_ind, frame_ind)``
        tuples and every ``BaseDetDataset`` contributes plain image indices.
        Image and video entries are interleaved (a debugging convenience
        kept from the original implementation).
        """
        # A plain, non-concatenated dataset is treated as a single
        # sub-dataset instead of failing on the missing ``datasets``.
        sub_datasets = getattr(self.dataset, "datasets", None)
        if sub_datasets is None:
            sub_datasets = [self.dataset]

        video_indices = []
        img_indices = []
        for dataset in sub_datasets:
            self.test_mode = dataset.test_mode
            assert not self.test_mode, (
                "'ConcatDataset' should not exist in test mode"
            )

            if isinstance(dataset, BaseVideoDataset):
                for video_ind in range(len(dataset)):
                    num_frames = dataset.get_len_per_video(video_ind)
                    video_indices.extend(
                        (video_ind, frame_ind) for frame_ind in range(num_frames)
                    )
            elif isinstance(dataset, BaseDetDataset):
                img_indices.extend(range(len(dataset)))

        self.indices = self._alternate_merge(img_indices, video_indices)

    @property
    def metainfo(self) -> dict:
        """Get the meta information of the multi-image-mixed dataset.

        Returns:
            dict: A deep copy of the meta information of the wrapped dataset.
        """
        return copy.deepcopy(self._metainfo)

    def full_init(self):
        """Call ``full_init`` of the wrapped dataset (only once)."""
        if self._fully_initialized:
            return
        self.dataset.full_init()
        self._ori_len = len(self.dataset)
        self._fully_initialized = True

    def get_data_info(self, idx: int) -> dict:
        """Get annotation by index.

        Args:
            idx (int): Index forwarded to the wrapped dataset.

        Returns:
            dict: The idx-th annotation of the wrapped dataset.
        """
        return self.dataset.get_data_info(idx)

    def _sample_mix_results(self, transform):
        """Fetch the extra samples requested by ``transform.get_indexes``.

        Retries up to ``max_refetch`` times until none of the fetched
        samples is None.

        Raises:
            RuntimeError: If every attempt yields at least one None sample.
        """
        for _ in range(self.max_refetch):
            # Make sure the results passed the loading pipeline
            # of the original dataset is not None.
            indexes = transform.get_indexes(self.dataset)
            if not isinstance(indexes, collections.abc.Sequence):
                indexes = [indexes]
            mix_results = [copy.deepcopy(self.dataset[index]) for index in indexes]
            if None not in mix_results:
                return mix_results
        raise RuntimeError(
            "The loading pipeline of the original dataset"
            " always return None. Please check the correctness "
            "of the dataset and its pipeline."
        )

    def get_transform_indexes(self, transform, results, t_type="SeqMosaic"):
        """Attach mix samples for a sequence-level mixing transform.

        Args:
            transform: Transform exposing ``get_indexes``.
            results (dict): Current sample; ``results["img_id"]`` determines
                how many times the sampled list is repeated.
            t_type (str): Class name of ``transform``. Only ``"SeqMosaic"``,
                ``"SeqMixUp"`` and ``"SeqCopyPaste"`` store anything; other
                names leave ``results`` unchanged.

        Returns:
            dict: ``results`` with the matching ``*_mix_results`` key set.

        Raises:
            RuntimeError: If no valid mix samples could be fetched.
        """
        num_samples = len(results["img_id"])
        mix_results = self._sample_mix_results(transform)
        key = self._MIX_RESULT_KEYS.get(t_type)
        if key is not None:
            # The same list object is shared by every entry, i.e. all
            # frames of the sample are mixed with the same extra images.
            results[key] = [mix_results] * num_samples
        return results

    def _apply_transform(self, transform, results):
        """Run ``transform`` on a copy of ``results`` with bounded retries.

        An attempt fails when the transform raises or returns None. The
        input is never modified, so each retry starts from the same state.

        Raises:
            RuntimeError: If all ``max_refetch`` attempts fail. The last
                exception raised by the transform (if any) is chained as
                the cause.
        """
        last_error = None
        for _ in range(self.max_refetch):
            # To confirm the results passed the training pipeline
            # of the wrapper is not None.
            try:
                updated_results = transform(copy.deepcopy(results))
            except Exception as err:  # best-effort retry, see docstring
                print(
                    "Error occurred while running pipeline",
                    f"{transform} with error: {err}",
                )
                last_error = err
                continue
            if updated_results is not None:
                return updated_results
        raise RuntimeError(
            "The training pipeline of the dataset wrapper"
            " always return None.Please check the correctness "
            "of the dataset and its pipeline."
        ) from last_error

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        """Load sample ``idx`` and run it through the wrapper pipeline.

        Raises:
            RuntimeError: If mix samples cannot be fetched or a transform
                keeps failing for ``max_refetch`` attempts.
        """
        results = copy.deepcopy(self.dataset[idx])
        for transform, transform_type in zip(self.pipeline, self.pipeline_types):
            if (
                self._skip_type_keys is not None
                and transform_type in self._skip_type_keys
            ):
                continue

            if transform_type == "MasaTransformBroadcaster":
                # The broadcaster wraps several transforms; each one that
                # mixes images gets its own ``*_mix_results`` entry.
                for sub_transform in transform.transforms:
                    if hasattr(sub_transform, "get_indexes"):
                        results = self.get_transform_indexes(
                            sub_transform, results, type(sub_transform).__name__
                        )
            elif hasattr(transform, "get_indexes"):
                results["mix_results"] = self._sample_mix_results(transform)

            # A failing transform is retried on the same sample; the sample
            # itself is not re-drawn (entries of ``self.indices`` are not
            # valid indices into ``self.dataset``).
            results = self._apply_transform(transform, results)

            # The temporary mix entries are only meant for the transform
            # that just ran; drop them so they are not copied further.
            for key in self._MIX_RESULT_KEYS.values():
                results.pop(key, None)

        return results

    def update_skip_type_keys(self, skip_type_keys):
        """Update skip_type_keys. It is called by an external hook.

        Args:
            skip_type_keys (list[str]): Transform type names to be skipped
                by the pipeline.
        """
        assert all(isinstance(key, str) for key in skip_type_keys)
        self._skip_type_keys = skip_type_keys

    def _rand_another(self, idx):
        """Return a random entry of ``self.indices``.

        ``idx`` is accepted for interface compatibility but not used. A
        random position is drawn instead of calling ``np.random.choice`` on
        the list, because the entries may be ``(video_ind, frame_ind)``
        tuples mixed with ints, which NumPy cannot treat as a 1-D array.
        """
        return self.indices[np.random.randint(len(self.indices))]
class SeqRandomMultiImageVideoMixDataset(SeqMultiImageMixDataset):
    """Mixed dataset that randomly draws either a video or an image sample.

    On every ``__getitem__`` call a coin is flipped with probability
    ``video_sample_ratio``: on success the sample is read from
    ``self.dataset[0]`` and processed by ``video_pipeline``, otherwise it is
    read from ``self.dataset[1]`` and processed by the regular ``pipeline``.

    NOTE(review): presumably ``RandomSampleJointVideoConcatDataset`` treats
    the index as a sub-dataset selector (0 = video, 1 = image) and picks the
    actual sample at random -- confirm against that class.

    Args:
        video_pipeline (Sequence[dict]): Transform config dicts used for
            video samples; each one is built through ``TRANSFORMS.build``.
        video_sample_ratio (float): Probability of drawing a video sample.
            Default: 0.5.
        *args, **kwargs: Forwarded to ``SeqMultiImageMixDataset``.
    """

    def __init__(
        self, video_pipeline: Sequence[dict], video_sample_ratio=0.5, *args, **kwargs
    ):
        super().__init__(*args, **kwargs)

        # Built transforms and their configured type names, kept in lockstep.
        self.video_pipeline = []
        self.video_pipeline_types = []
        for transform in video_pipeline:
            if not isinstance(transform, dict):
                raise TypeError("pipeline must be a dict")
            self.video_pipeline_types.append(transform["type"])
            self.video_pipeline.append(TRANSFORMS.build(transform))

        self.video_sample_ratio = video_sample_ratio
        assert isinstance(self.dataset, RandomSampleJointVideoConcatDataset)

    def get_transform_indexes(
        self, transform, results, sample_video, t_type="SeqMosaic"
    ):
        """Attach mix samples for a sequence-level mixing transform.

        Args:
            transform: Transform exposing ``get_indexes``; it is queried with
                ``self.dataset.datasets[0]`` and only the *number* of
                returned indexes is used.
            results (dict): Current sample; ``results["img_id"]`` determines
                how many times the sampled list is repeated.
            sample_video (bool): Draw the mix samples from
                ``self.dataset[0]`` if True, else from ``self.dataset[1]``.
            t_type (str): Class name of ``transform``. Only ``"SeqMosaic"``,
                ``"SeqMixUp"`` and ``"SeqCopyPaste"`` store anything; other
                names leave ``results`` unchanged.

        Returns:
            dict: ``results`` with the matching ``*_mix_results`` key set.

        Raises:
            RuntimeError: If every attempt yields at least one None sample.
        """
        result_keys = {
            "SeqMosaic": "mosaic_mix_results",
            "SeqMixUp": "mixup_mix_results",
            "SeqCopyPaste": "copypaste_mix_results",
        }
        num_samples = len(results["img_id"])
        source = 0 if sample_video else 1

        for _ in range(self.max_refetch):
            # Make sure the results passed the loading pipeline
            # of the original dataset is not None.
            indexes = transform.get_indexes(self.dataset.datasets[0])
            if not isinstance(indexes, collections.abc.Sequence):
                indexes = [indexes]
            # One fresh draw from the chosen source per requested index.
            mix_results = [copy.deepcopy(self.dataset[source]) for _ in indexes]
            if None in mix_results:
                continue
            key = result_keys.get(t_type)
            if key is not None:
                # The same list object is shared by every entry, i.e. all
                # frames of the sample are mixed with the same extra images.
                results[key] = [mix_results] * num_samples
            return results

        raise RuntimeError(
            "The loading pipeline of the original dataset"
            " always return None. Please check the correctness "
            "of the dataset and its pipeline."
        )

    def __getitem__(self, idx):
        """Draw a random video or image sample and run its pipeline.

        ``idx`` is not used; the sample source is chosen at random.

        Raises:
            RuntimeError: If mix samples cannot be fetched or a transform
                keeps failing for ``max_refetch`` attempts. In the latter
                case the last exception raised by the transform (if any) is
                chained as the cause.
        """
        sample_video = random.random() < self.video_sample_ratio
        if sample_video:
            results = copy.deepcopy(self.dataset[0])
            pipeline = self.video_pipeline
            pipeline_types = self.video_pipeline_types
        else:
            results = copy.deepcopy(self.dataset[1])
            pipeline = self.pipeline
            pipeline_types = self.pipeline_types

        for transform, transform_type in zip(pipeline, pipeline_types):
            if (
                self._skip_type_keys is not None
                and transform_type in self._skip_type_keys
            ):
                continue

            if transform_type == "MasaTransformBroadcaster":
                # The broadcaster wraps several transforms; each one that
                # mixes images gets its own ``*_mix_results`` entry.
                for sub_transform in transform.transforms:
                    if hasattr(sub_transform, "get_indexes"):
                        results = self.get_transform_indexes(
                            sub_transform,
                            results,
                            sample_video,
                            type(sub_transform).__name__,
                        )
            elif hasattr(transform, "get_indexes"):
                # NOTE(review): here the returned indexes are used to index
                # ``self.dataset`` directly, unlike the broadcaster branch
                # above -- confirm this is intended for the joint dataset.
                for _ in range(self.max_refetch):
                    # Make sure the results passed the loading pipeline
                    # of the original dataset is not None.
                    indexes = transform.get_indexes(self.dataset)
                    if not isinstance(indexes, collections.abc.Sequence):
                        indexes = [indexes]
                    mix_results = [
                        copy.deepcopy(self.dataset[index]) for index in indexes
                    ]
                    if None not in mix_results:
                        results["mix_results"] = mix_results
                        break
                else:
                    raise RuntimeError(
                        "The loading pipeline of the original dataset"
                        " always return None. Please check the correctness "
                        "of the dataset and its pipeline."
                    )

            # A failing transform is retried on a fresh copy of the same
            # sample; the sample itself is not re-drawn.
            last_error = None
            for _ in range(self.max_refetch):
                # To confirm the results passed the training pipeline
                # of the wrapper is not None.
                try:
                    updated_results = transform(copy.deepcopy(results))
                except Exception as err:  # best-effort retry
                    print(
                        "Error occurred while running pipeline",
                        f"{transform} with error: {err}",
                    )
                    last_error = err
                    continue
                if updated_results is not None:
                    results = updated_results
                    break
            else:
                raise RuntimeError(
                    "The training pipeline of the dataset wrapper"
                    " always return None.Please check the correctness "
                    "of the dataset and its pipeline."
                ) from last_error

            # The temporary mix entries are only meant for the transform
            # that just ran; drop them so they are not copied further.
            for key in (
                "mosaic_mix_results",
                "mixup_mix_results",
                "copypaste_mix_results",
            ):
                results.pop(key, None)

        return results