# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import copy
import functools
import gc
import logging
import pickle
from typing import Any, Callable, List, Optional, Sequence, Tuple, Union

import numpy as np
from torch.utils.data import Dataset

from mmengine.fileio import join_path, list_from_file, load
from mmengine.logging import print_log
from mmengine.registry import TRANSFORMS
from mmengine.utils import is_abs
class Compose:
    """Compose multiple transforms sequentially.

    Args:
        transforms (Sequence[dict, callable], optional): Sequence of transform
            objects or config dicts to be composed.
    """

    def __init__(self, transforms: Optional[Sequence[Union[dict, Callable]]]):
        self.transforms: List[Callable] = []

        if transforms is None:
            transforms = []

        for transform in transforms:
            # `Compose` can be built with config dicts with a type and
            # corresponding arguments.
            if isinstance(transform, dict):
                transform = TRANSFORMS.build(transform)
                if not callable(transform):
                    raise TypeError(f'transform should be a callable object, '
                                    f'but got {type(transform)}')
                self.transforms.append(transform)
            elif callable(transform):
                self.transforms.append(transform)
            else:
                raise TypeError(
                    f'transform must be a callable object or dict, '
                    f'but got {type(transform)}')

    def __call__(self, data: dict) -> Optional[dict]:
        """Call function to apply transforms sequentially.

        Args:
            data (dict): A result dict that contains the data to transform.

        Returns:
            dict: Transformed data.
        """
        for t in self.transforms:
            data = t(data)
            # The transform will return None when it failed to load images or
            # cannot find suitable augmentation parameters to augment the data.
            # Here we simply return None if the transform returns None and the
            # dataset will handle it by randomly selecting another data sample.
            if data is None:
                return None
        return data

    def __repr__(self):
        """Print ``self.transforms`` in sequence.

        Returns:
            str: Formatted string.
        """
        format_string = self.__class__.__name__ + '('
        for t in self.transforms:
            format_string += '\n'
            format_string += f'    {t}'
        format_string += '\n)'
        return format_string
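
# Example usage of ``Compose`` (a minimal sketch; ``Flip`` is a hypothetical
# transform assumed to be registered in ``TRANSFORMS`` and is not defined in
# this file):
#
#     >>> pipeline = Compose([
#     ...     dict(type='Flip'),        # config dict, built via TRANSFORMS.build
#     ...     lambda results: results,  # plain callables are also accepted
#     ... ])
#     >>> results = pipeline(dict(img_path='demo.jpg'))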
def force_full_init(old_func: Callable) -> Any:
    """Those methods decorated by ``force_full_init`` will be forced to call
    ``full_init`` if the instance has not been fully initialized.

    Args:
        old_func (Callable): Decorated function, make sure the first arg is an
            instance with ``full_init`` method.

    Returns:
        Any: Depends on old_func.
    """

    @functools.wraps(old_func)
    def wrapper(obj: object, *args, **kwargs):
        # The instance must have a `full_init` method.
        if not hasattr(obj, 'full_init'):
            raise AttributeError(f'{type(obj)} does not have full_init '
                                 'method.')
        # If the instance does not have the `_fully_initialized` attribute or
        # `_fully_initialized` is False, call `full_init` and set
        # `_fully_initialized` to True.
        if not getattr(obj, '_fully_initialized', False):
            print_log(
                f'Attribute `_fully_initialized` is not defined in '
                f'{type(obj)} or `{type(obj)}._fully_initialized` is '
                'False, `full_init` will be called and '
                f'{type(obj)}._fully_initialized will be set to True',
                logger='current',
                level=logging.WARNING)
            obj.full_init()  # type: ignore
            obj._fully_initialized = True  # type: ignore

        return old_func(obj, *args, **kwargs)

    return wrapper
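
# Example (a minimal sketch): ``force_full_init`` is meant to decorate methods
# of classes exposing a ``full_init`` method, such as ``BaseDataset`` below.
# A hypothetical subclass method could use it like this:
#
#     >>> class MyDataset(BaseDataset):
#     ...     @force_full_init
#     ...     def get_img_path(self, idx: int) -> str:
#     ...         # `full_init` is guaranteed to have run at this point.
#     ...         return self.get_data_info(idx)['img_path']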
class BaseDataset(Dataset):
    r"""BaseDataset for open source projects in OpenMMLab.

    The annotation format is shown as follows.

    .. code-block:: none

        {
            "metainfo":
            {
              "dataset_type": "test_dataset",
              "task_name": "test_task"
            },
            "data_list":
            [
              {
                "img_path": "test_img.jpg",
                "height": 604,
                "width": 640,
                "instances":
                [
                  {
                    "bbox": [0, 0, 10, 20],
                    "bbox_label": 1,
                    "mask": [[0,0],[0,10],[10,20],[20,0]],
                    "extra_anns": [1,2,3]
                  },
                  {
                    "bbox": [10, 10, 110, 120],
                    "bbox_label": 2,
                    "mask": [[10,10],[10,110],[110,120],[120,10]],
                    "extra_anns": [4,5,6]
                  }
                ]
              },
            ]
        }

    Args:
        ann_file (str, optional): Annotation file path. Defaults to ''.
        metainfo (dict, optional): Meta information for dataset, such as class
            information. Defaults to None.
        data_root (str, optional): The root directory for ``data_prefix`` and
            ``ann_file``. Defaults to ''.
        data_prefix (dict): Prefix for training data. Defaults to
            dict(img_path='').
        filter_cfg (dict, optional): Config for filtering data. Defaults to
            None.
        indices (int or Sequence[int], optional): Support using only the first
            few data in the annotation file to facilitate training/testing on
            a smaller dataset. Defaults to None, which means using all data.
        serialize_data (bool, optional): Whether to hold memory using
            serialized objects. When enabled, data loader workers can use
            shared RAM from the master process instead of making a copy.
            Defaults to True.
        pipeline (list, optional): Processing pipeline. Defaults to [].
        test_mode (bool, optional): ``test_mode=True`` means in test phase.
            Defaults to False.
        lazy_init (bool, optional): Whether to load annotations during
            instantiation. In some cases, such as visualization, only the
            meta information of the dataset is needed, so it is not necessary
            to load the annotation file. ``BaseDataset`` can skip loading
            annotations to save time by setting ``lazy_init=True``. Defaults
            to False.
        max_refetch (int, optional): If ``BaseDataset.prepare_data`` gets a
            None img, the maximum number of extra cycles to get a valid
            image. Defaults to 1000.

    Note:
        BaseDataset collects meta information from the ``annotation file``
        (the lowest priority), ``BaseDataset.METAINFO`` (medium) and the
        ``metainfo`` parameter (highest) passed to the constructor. The lower
        priority meta information will be overwritten by the higher one.

    Note:
        Dataset wrappers such as ``ConcatDataset``, ``RepeatDataset``, etc.
        should not inherit from ``BaseDataset``, since ``get_subset`` and
        ``get_subset_`` could produce an ambiguous sub-dataset which conflicts
        with the original dataset.

    Examples:
        >>> # Assume the annotation file is given above.
        >>> class CustomDataset(BaseDataset):
        >>>     METAINFO: dict = dict(task_name='custom_task',
        >>>                           dataset_type='custom_type')
        >>> metainfo = dict(task_name='custom_task_name')
        >>> custom_dataset = CustomDataset(
        >>>     'path/to/ann_file',
        >>>     metainfo=metainfo)
        >>> # Meta information of the annotation file will be overwritten by
        >>> # `CustomDataset.METAINFO`. The merged meta information will
        >>> # further be overwritten by the argument `metainfo`.
        >>> custom_dataset.metainfo
        {'task_name': 'custom_task_name', 'dataset_type': 'custom_type'}
    """
    METAINFO: dict = dict()
    _fully_initialized: bool = False

    def __init__(self,
                 ann_file: Optional[str] = '',
                 metainfo: Optional[dict] = None,
                 data_root: Optional[str] = '',
                 data_prefix: dict = dict(img_path=''),
                 filter_cfg: Optional[dict] = None,
                 indices: Optional[Union[int, Sequence[int]]] = None,
                 serialize_data: bool = True,
                 pipeline: List[Union[dict, Callable]] = [],
                 test_mode: bool = False,
                 lazy_init: bool = False,
                 max_refetch: int = 1000):
        self.ann_file = ann_file
        self._metainfo = self._load_metainfo(copy.deepcopy(metainfo))
        self.data_root = data_root
        self.data_prefix = copy.copy(data_prefix)
        self.filter_cfg = copy.deepcopy(filter_cfg)
        self._indices = indices
        self.serialize_data = serialize_data
        self.test_mode = test_mode
        self.max_refetch = max_refetch
        self.data_list: List[dict] = []
        self.data_bytes: np.ndarray

        # Join paths.
        self._join_prefix()

        # Build pipeline.
        self.pipeline = Compose(pipeline)
        # Fully initialize the dataset.
        if not lazy_init:
            self.full_init()
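
    # Example (a minimal sketch, assuming 'path/to/ann_file' follows the
    # annotation format documented above): with ``lazy_init=True`` only the
    # meta information is prepared, and annotations are loaded on demand.
    #
    #     >>> dataset = BaseDataset('path/to/ann_file', lazy_init=True)
    #     >>> dataset.metainfo     # available without loading annotations
    #     >>> dataset.full_init()  # recommended before the dataloader starts
    #     >>> info = dataset.get_data_info(0)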
    @force_full_init
    def get_data_info(self, idx: int) -> dict:
        """Get annotation by index and automatically call ``full_init`` if the
        dataset has not been fully initialized.

        Args:
            idx (int): The index of data.

        Returns:
            dict: The idx-th annotation of the dataset.
        """
        if self.serialize_data:
            start_addr = 0 if idx == 0 else self.data_address[idx - 1].item()
            end_addr = self.data_address[idx].item()
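            # `self.data_address` stores the cumulative end offset of each
            # pickled sample inside `self.data_bytes`, so the bytes of sample
            # `idx` live in the half-open slice [start_addr, end_addr).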
            bytes = memoryview(
                self.data_bytes[start_addr:end_addr])  # type: ignore
            data_info = pickle.loads(bytes)  # type: ignore
        else:
            data_info = copy.deepcopy(self.data_list[idx])

        # Some codebases need the `sample_idx` of the data information. Here
        # we convert the idx to a positive number and save it in the data
        # information.
        if idx >= 0:
            data_info['sample_idx'] = idx
        else:
            data_info['sample_idx'] = len(self) + idx

        return data_info
    def full_init(self):
        """Load annotation file and set ``BaseDataset._fully_initialized`` to
        True.

        If ``lazy_init=False``, ``full_init`` will be called during the
        instantiation and ``self._fully_initialized`` will be set to True. If
        ``obj._fully_initialized=False``, the class method decorated by
        ``force_full_init`` will call ``full_init`` automatically.

        Several steps to initialize annotation:

            - load_data_list: Load annotations from annotation file.
            - filter data information: Filter annotations according to
              filter_cfg.
            - slice_data: Slice dataset according to ``self._indices``
            - serialize_data: Serialize ``self.data_list`` if
              ``self.serialize_data`` is True.
        """
        if self._fully_initialized:
            return
        # Load data information.
        self.data_list = self.load_data_list()
        # Filter illegal data, such as data that has no annotations.
        self.data_list = self.filter_data()
        # Get subset data according to indices.
        if self._indices is not None:
            self.data_list = self._get_unserialized_subset(self._indices)

        # Serialize data_list.
        if self.serialize_data:
            self.data_bytes, self.data_address = self._serialize_data()

        self._fully_initialized = True
    @property
    def metainfo(self) -> dict:
        """Get meta information of dataset.

        Returns:
            dict: meta information collected from ``BaseDataset.METAINFO``,
            annotation file and metainfo argument during instantiation.
        """
        return copy.deepcopy(self._metainfo)
    def parse_data_info(self, raw_data_info: dict) -> Union[dict, List[dict]]:
        """Parse raw annotation to target format.

        This method should return a dict or a list of dicts. Each dict
        contains the data information of a training sample. If the protocol of
        the sample annotations is changed, this function can be overridden to
        update the parsing logic while keeping compatibility.

        Args:
            raw_data_info (dict): Raw data information loaded from
                ``ann_file``.

        Returns:
            dict or list[dict]: Parsed annotation.
        """
        for prefix_key, prefix in self.data_prefix.items():
            assert prefix_key in raw_data_info, (
                f'raw_data_info: {raw_data_info} does not contain prefix key '
                f'{prefix_key}, please check your data_prefix.')
            raw_data_info[prefix_key] = join_path(prefix,
                                                  raw_data_info[prefix_key])
        return raw_data_info
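
    # Example (a minimal sketch): with ``data_prefix=dict(img_path='imgs')``
    # and ``data_root='data/'`` already joined by ``_join_prefix``, a raw item
    # such as ``dict(img_path='test_img.jpg', ...)`` is returned as
    # ``dict(img_path='data/imgs/test_img.jpg', ...)``, assuming ``join_path``
    # concatenates like ``os.path.join``.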
    def filter_data(self) -> List[dict]:
        """Filter annotations according to filter_cfg. By default, return the
        full ``data_list``.

        If some items of ``data_list`` should be filtered out according to
        specific logic, the subclass should override this method.

        Returns:
            list[dict]: Filtered results.
        """
        return self.data_list

    def get_cat_ids(self, idx: int) -> List[int]:
        """Get category ids by index. A dataset wrapped by
        ``ClassBalancedDataset`` must implement this method.

        Args:
            idx (int): The index of data.

        Returns:
            list[int]: All categories in the image of specified index.
        """
        raise NotImplementedError(f'{type(self)} must implement `get_cat_ids` '
                                  'method')
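
    # Example (a minimal sketch of an override; the `instances`/`bbox_label`
    # keys follow the annotation format documented above and are assumptions
    # for any other dataset):
    #
    #     >>> class DetDataset(BaseDataset):
    #     ...     def get_cat_ids(self, idx: int) -> List[int]:
    #     ...         instances = self.get_data_info(idx)['instances']
    #     ...         return [ins['bbox_label'] for ins in instances]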
    def __getitem__(self, idx: int) -> dict:
        """Get the idx-th image and data information of dataset after
        ``self.pipeline``, and ``full_init`` will be called if the dataset has
        not been fully initialized.

        During the training phase, if ``self.pipeline`` gets ``None``,
        ``self._rand_another`` will be called until a valid image is fetched
        or the maximum limit of refetch is reached.

        Args:
            idx (int): The index of self.data_list.

        Returns:
            dict: The idx-th image and data information of dataset after
            ``self.pipeline``.
        """
        # Performing full initialization by calling `__getitem__` will consume
        # extra memory. If a dataset is not fully initialized (by setting
        # `lazy_init=True`) and then fed into the dataloader, different
        # workers will simultaneously read and parse the annotation. It will
        # cost more time and memory, although this may work. Therefore, it is
        # recommended to manually call `full_init` before the dataset is fed
        # into the dataloader to ensure all workers use shared RAM from the
        # master process.
        if not self._fully_initialized:
            print_log(
                'Please call `full_init()` method manually to accelerate '
                'the speed.',
                logger='current',
                level=logging.WARNING)
            self.full_init()

        if self.test_mode:
            data = self.prepare_data(idx)
            if data is None:
                raise Exception('Test time pipeline should not get `None` '
                                'data_sample')
            return data

        for _ in range(self.max_refetch + 1):
            data = self.prepare_data(idx)
            # Broken images or random augmentations may cause the returned
            # data to be None.
            if data is None:
                idx = self._rand_another()
                continue
            return data

        raise Exception(f'Cannot find valid image after {self.max_refetch} '
                        'retries! Please check your image path and pipeline')
    def load_data_list(self) -> List[dict]:
        """Load annotations from an annotation file named ``self.ann_file``.

        If the annotation file does not follow the `OpenMMLab 2.0 dataset
        format <https://mmengine.readthedocs.io/en/latest/advanced_tutorials/basedataset.html>`_,
        the subclass must override this method to load annotations. The meta
        information in the annotation file will be overwritten by
        :attr:`METAINFO` and the ``metainfo`` argument of the constructor.

        Returns:
            list[dict]: A list of annotations.
        """  # noqa: E501
        # `self.ann_file` denotes the absolute annotation file path if
        # `self.data_root` is not set, or a path relative to
        # `self.data_root` (e.g. `/path/to/data/`) otherwise.
        annotations = load(self.ann_file)
        if not isinstance(annotations, dict):
            raise TypeError(f'The annotations loaded from annotation file '
                            f'should be a dict, but got {type(annotations)}!')
        if 'data_list' not in annotations or 'metainfo' not in annotations:
            raise ValueError('Annotation must have data_list and metainfo '
                             'keys')
        metainfo = annotations['metainfo']
        raw_data_list = annotations['data_list']

        # Meta information loaded from the annotation file will not override
        # the existing meta information loaded from `BaseDataset.METAINFO`
        # and the `metainfo` argument of the constructor.
        for k, v in metainfo.items():
            self._metainfo.setdefault(k, v)

        # Load and parse data_infos.
        data_list = []
        for raw_data_info in raw_data_list:
            # Parse raw data information to target format.
            data_info = self.parse_data_info(raw_data_info)
            if isinstance(data_info, dict):
                # For image tasks, `data_info` should contain the information
                # of a single image, such as:
                # dict(img_path='xxx', width=360, ...)
                data_list.append(data_info)
            elif isinstance(data_info, list):
                # For video tasks, `data_info` could contain image
                # information of multiple frames, such as
                # [dict(video_path='xxx', timestamps=...),
                #  dict(video_path='xxx', timestamps=...)]
                for item in data_info:
                    if not isinstance(item, dict):
                        raise TypeError('data_info must be a list of dict, '
                                        f'but got {type(item)}')
                data_list.extend(data_info)
            else:
                raise TypeError('data_info should be a dict or a list of '
                                f'dict, but got {type(data_info)}')

        return data_list
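
    # Example (a minimal sketch of an override for a non-OpenMMLab-format
    # annotation file; the plain-text layout `<img_path> <label>` per line is
    # purely hypothetical):
    #
    #     >>> class TxtDataset(BaseDataset):
    #     ...     def load_data_list(self) -> List[dict]:
    #     ...         data_list = []
    #     ...         for line in list_from_file(self.ann_file):
    #     ...             img_path, label = line.rsplit(' ', 1)
    #     ...             data_list.append(
    #     ...                 self.parse_data_info(
    #     ...                     dict(img_path=img_path,
    #     ...                          gt_label=int(label))))
    #     ...         return data_list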
    @classmethod
    def _load_metainfo(cls, metainfo: Optional[dict] = None) -> dict:
        """Collect meta information from the dictionary of meta.

        Args:
            metainfo (dict): Meta information dict. If ``metainfo``
                contains an existing filename, it will be parsed by
                ``list_from_file``.

        Returns:
            dict: Parsed meta information.
        """
        # Avoid `cls.METAINFO` being overwritten by `metainfo`.
        cls_metainfo = copy.deepcopy(cls.METAINFO)
        if metainfo is None:
            return cls_metainfo
        if not isinstance(metainfo, dict):
            raise TypeError(
                f'metainfo should be a dict, but got {type(metainfo)}')

        for k, v in metainfo.items():
            if isinstance(v, str):
                # If the value is a string that can be loaded from the
                # corresponding backend, it is treated as the filename of a
                # meta file.
                try:
                    cls_metainfo[k] = list_from_file(v)
                except (TypeError, FileNotFoundError):
                    print_log(
                        f'{v} is not a meta file, simply parsed as meta '
                        'information',
                        logger='current',
                        level=logging.WARNING)
                    cls_metainfo[k] = v
            else:
                cls_metainfo[k] = v
        return cls_metainfo
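
    # Example (a minimal sketch): string values in ``metainfo`` may point to a
    # plain-text file, e.g. one class name per line; 'path/to/classes.txt' is
    # a hypothetical path.
    #
    #     >>> BaseDataset._load_metainfo(dict(classes='path/to/classes.txt'))
    #     {'classes': ['cat', 'dog', ...]}   # parsed by `list_from_file`
    #     >>> BaseDataset._load_metainfo(dict(classes=('cat', 'dog')))
    #     {'classes': ('cat', 'dog')}        # non-string values kept as-is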
    def _join_prefix(self):
        """Join ``self.data_root`` with ``self.data_prefix`` and
        ``self.ann_file``.

        Examples:
            >>> # self.data_prefix contains relative paths
            >>> self.data_root = 'a/b/c'
            >>> self.data_prefix = dict(img='d/e/')
            >>> self.ann_file = 'f'
            >>> self._join_prefix()
            >>> self.data_prefix
            dict(img='a/b/c/d/e')
            >>> self.ann_file
            'a/b/c/f'
            >>> # self.data_prefix contains absolute paths
            >>> self.data_root = 'a/b/c'
            >>> self.data_prefix = dict(img='/d/e/')
            >>> self.ann_file = 'f'
            >>> self._join_prefix()
            >>> self.data_prefix
            dict(img='/d/e')
            >>> self.ann_file
            'a/b/c/f'
        """
        # Automatically join the annotation file path with `self.data_root`
        # if `self.ann_file` is not an absolute path.
        if self.ann_file and not is_abs(self.ann_file) and self.data_root:
            self.ann_file = join_path(self.data_root, self.ann_file)
        # Automatically join each data directory with `self.data_root` if the
        # path value in `self.data_prefix` is not an absolute path.
        for data_key, prefix in self.data_prefix.items():
            if not isinstance(prefix, str):
                raise TypeError('prefix should be a string, but got '
                                f'{type(prefix)}')
            if not is_abs(prefix) and self.data_root:
                self.data_prefix[data_key] = join_path(self.data_root, prefix)
            else:
                self.data_prefix[data_key] = prefix
    @force_full_init
    def get_subset_(self, indices: Union[Sequence[int], int]) -> None:
        """The in-place version of ``get_subset``, which converts the dataset
        to a subset of the original dataset.

        If the type of indices is int, the dataset keeps the first or last
        few data information entries according to whether indices is positive
        or negative. If the type of indices is a sequence of int, the dataset
        keeps the data information at the indices given in the sequence.

        Examples:
            >>> dataset = BaseDataset('path/to/ann_file')
            >>> len(dataset)
            100
            >>> dataset.get_subset_(90)
            >>> len(dataset)
            90
            >>> # if type of indices is sequence, extract the corresponding
            >>> # index data information
            >>> dataset.get_subset_([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
            >>> len(dataset)
            10
            >>> dataset.get_subset_(-3)
            >>> len(dataset)  # Get the last few data information.
            3

        Args:
            indices (int or Sequence[int]): If type of indices is int, indices
                represents the first or last few data of dataset according to
                whether indices is positive or negative. If type of indices is
                Sequence, indices represents the target data information
                indices of the dataset.
        """
        # Get a subset of data from serialized data or the data information
        # sequence according to `self.serialize_data`.
        if self.serialize_data:
            self.data_bytes, self.data_address = \
                self._get_serialized_subset(indices)
        else:
            self.data_list = self._get_unserialized_subset(indices)
    @force_full_init
    def get_subset(self, indices: Union[Sequence[int], int]) -> 'BaseDataset':
        """Return a subset of the dataset.

        If the type of indices is int, the subdataset will contain the first
        or last few data information entries according to whether indices is
        positive or negative. If the type of indices is a sequence of int,
        the subdataset will extract the data information at the indices given
        in the sequence.

        Examples:
            >>> dataset = BaseDataset('path/to/ann_file')
            >>> len(dataset)
            100
            >>> subdataset = dataset.get_subset(90)
            >>> len(subdataset)
            90
            >>> # if type of indices is list, extract the corresponding
            >>> # index data information
            >>> subdataset = dataset.get_subset([0, 1, 2, 3, 4, 5, 6, 7,
            >>>                                  8, 9])
            >>> len(subdataset)
            10
            >>> subdataset = dataset.get_subset(-3)
            >>> len(subdataset)  # Get the last few data information.
            3

        Args:
            indices (int or Sequence[int]): If type of indices is int, indices
                represents the first or last few data of dataset according to
                whether indices is positive or negative. If type of indices is
                Sequence, indices represents the target data information
                indices of the dataset.

        Returns:
            BaseDataset: A subset of dataset.
        """
        # Get a subset of data from serialized data or the data information
        # list according to `self.serialize_data`. Since
        # `_get_serialized_subset` will recalculate the subset data
        # information, `_copy_without_annotation` will copy all attributes
        # except the data information.
        sub_dataset = self._copy_without_annotation()
        # Get a subset of the dataset with serialized or unserialized data.
        if self.serialize_data:
            data_bytes, data_address = \
                self._get_serialized_subset(indices)
            sub_dataset.data_bytes = data_bytes.copy()
            sub_dataset.data_address = data_address.copy()
        else:
            data_list = self._get_unserialized_subset(indices)
            sub_dataset.data_list = copy.deepcopy(data_list)
        return sub_dataset
    def _get_serialized_subset(self, indices: Union[Sequence[int], int]) \
            -> Tuple[np.ndarray, np.ndarray]:
        """Get a subset of the serialized data information list.

        Args:
            indices (int or Sequence[int]): If type of indices is int,
                indices represents the first or last few data of the
                serialized data information list. If type of indices is
                Sequence, indices represents the target data information
                indices which make up the subset.

        Returns:
            Tuple[np.ndarray, np.ndarray]: Subset of serialized data
            information.
        """
        sub_data_bytes: Union[List, np.ndarray]
        sub_data_address: Union[List, np.ndarray]
        if isinstance(indices, int):
            if indices >= 0:
                assert indices < len(self.data_address), \
                    f'{indices} is out of dataset length({len(self)})'
                # Return the first few data information.
                end_addr = self.data_address[indices - 1].item() \
                    if indices > 0 else 0
                # Slicing operation of `np.ndarray` does not trigger a memory
                # copy.
                sub_data_bytes = self.data_bytes[:end_addr]
                # Since the buffer offsets of the first few data information
                # are unchanged, the corresponding addresses can be reused
                # directly.
                sub_data_address = self.data_address[:indices]
            else:
                assert -indices <= len(self.data_address), \
                    f'{indices} is out of dataset length({len(self)})'
                # Return the last few data information.
                ignored_bytes_size = self.data_address[indices - 1]
                start_addr = self.data_address[indices - 1].item()
                sub_data_bytes = self.data_bytes[start_addr:]
                sub_data_address = self.data_address[indices:]
                sub_data_address = sub_data_address - ignored_bytes_size
        elif isinstance(indices, Sequence):
            sub_data_bytes = []
            sub_data_address = []
            for idx in indices:
                assert len(self) > idx >= -len(self)
                start_addr = 0 if idx == 0 else \
                    self.data_address[idx - 1].item()
                end_addr = self.data_address[idx].item()
                # Get data information by address.
                sub_data_bytes.append(self.data_bytes[start_addr:end_addr])
                # Get data information size.
                sub_data_address.append(end_addr - start_addr)
            # Handle the case where indices is an empty list.
            if sub_data_bytes:
                sub_data_bytes = np.concatenate(sub_data_bytes)
                sub_data_address = np.cumsum(sub_data_address)
            else:
                sub_data_bytes = np.array([])
                sub_data_address = np.array([])
        else:
            raise TypeError('indices should be an int or a sequence of int, '
                            f'but got {type(indices)}')
        return sub_data_bytes, sub_data_address  # type: ignore
    def _get_unserialized_subset(self, indices: Union[Sequence[int],
                                                      int]) -> list:
        """Get a subset of the data information list.

        Args:
            indices (int or Sequence[int]): If type of indices is int,
                indices represents the first or last few data of the data
                information list. If type of indices is Sequence, indices
                represents the target data information indices which make up
                the subset.

        Returns:
            list: Subset of data information.
        """
        if isinstance(indices, int):
            if indices >= 0:
                # Return the first few data information.
                sub_data_list = self.data_list[:indices]
            else:
                # Return the last few data information.
                sub_data_list = self.data_list[indices:]
        elif isinstance(indices, Sequence):
            # Return the data information at the given indices.
            sub_data_list = []
            for idx in indices:
                sub_data_list.append(self.data_list[idx])
        else:
            raise TypeError('indices should be an int or a sequence of int, '
                            f'but got {type(indices)}')
        return sub_data_list
    def _serialize_data(self) -> Tuple[np.ndarray, np.ndarray]:
        """Serialize ``self.data_list`` to save memory when launching multiple
        workers in data loading. This function will be called in ``full_init``.

        Hold memory using serialized objects, so that data loader workers can
        use shared RAM from the master process instead of making a copy.

        Returns:
            Tuple[np.ndarray, np.ndarray]: Serialized result and corresponding
            address.
        """

        def _serialize(data):
            buffer = pickle.dumps(data, protocol=4)
            return np.frombuffer(buffer, dtype=np.uint8)

        # Serialize the data information list to avoid making multiple copies
        # of `self.data_list` when iterating over a
        # `torch.utils.data.DataLoader` with multiple workers.
        data_list = [_serialize(x) for x in self.data_list]
        address_list = np.asarray([len(x) for x in data_list], dtype=np.int64)
        data_address: np.ndarray = np.cumsum(address_list)
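        # For example, serialized sample sizes [3, 5, 2] yield cumulative
        # addresses [3, 8, 10]: sample i occupies
        # data_bytes[address[i - 1]:address[i]], with an implicit start of 0
        # for i == 0.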
        # TODO Check if np.concatenate is necessary
        data_bytes = np.concatenate(data_list)
        # Empty the cache to prevent making multiple copies of
        # `self.data_list` when loading data with multiple processes.
        self.data_list.clear()
        gc.collect()
        return data_bytes, data_address
    def _rand_another(self) -> int:
        """Get random index.

        Returns:
            int: Random index from 0 to ``len(self)-1``
        """
        return np.random.randint(0, len(self))

    def prepare_data(self, idx) -> Any:
        """Get data processed by ``self.pipeline``.

        Args:
            idx (int): The index of ``data_info``.

        Returns:
            Any: Depends on ``self.pipeline``.
        """
        data_info = self.get_data_info(idx)
        return self.pipeline(data_info)
    @force_full_init
    def __len__(self) -> int:
        """Get the length of the filtered dataset and automatically call
        ``full_init`` if the dataset has not been fully initialized.

        Returns:
            int: The length of the filtered dataset.
        """
        if self.serialize_data:
            return len(self.data_address)
        else:
            return len(self.data_list)
    def _copy_without_annotation(self,
                                 memo: Optional[dict] = None) -> 'BaseDataset':
        """Deepcopy for all attributes other than ``data_list``,
        ``data_address`` and ``data_bytes``.

        Args:
            memo: Memo dict used by ``copy.deepcopy`` to reconstruct complex
                objects correctly.
        """
        if memo is None:
            memo = {}
        cls = self.__class__
        other = cls.__new__(cls)
        memo[id(self)] = other

        for key, value in self.__dict__.items():
            if key in ['data_list', 'data_address', 'data_bytes']:
                continue
            super(BaseDataset, other).__setattr__(key,
                                                  copy.deepcopy(value, memo))

        return other
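
# Putting it together (a minimal sketch; 'path/to/ann_file' and the `Resize`
# transform are hypothetical, and registration with a DATASETS registry is
# omitted):
#
#     >>> class CustomDataset(BaseDataset):
#     ...     METAINFO = dict(classes=('cat', 'dog'))
#     ...
#     ...     def filter_data(self) -> List[dict]:
#     ...         # Drop samples without instances.
#     ...         return [info for info in self.data_list
#     ...                 if len(info.get('instances', [])) > 0]
#     >>> dataset = CustomDataset(
#     ...     ann_file='path/to/ann_file',
#     ...     data_root='data/',
#     ...     data_prefix=dict(img_path='imgs'),
#     ...     pipeline=[dict(type='Resize')])
#     >>> sample = dataset[0]  # processed by the pipeline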