Spaces:
Running
on
Zero
Running
on
Zero
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
# All rights reserved. | |
# | |
# This source code is licensed under the license found in the | |
# LICENSE file in the root directory of this source tree. | |
from typing import Any, Callable, Dict, List, Optional, Sequence, Union | |
import mmengine | |
import numpy as np | |
from .base import BaseTransform | |
from .builder import TRANSFORMS | |
from .utils import cache_random_params, cache_randomness | |
# A transform is given either as a config dict (to be built by the
# ``TRANSFORMS`` registry) or as a callable mapping a results dict to a
# results dict.
Transform = Union[
    Dict,
    Callable[[Dict], Dict],
]

# Sentinel that ``KeyMapper._map_input`` stores under keys which must stay
# invisible to the wrapped transforms (``KeyMapper._apply_transforms`` filters
# them out before calling the wrapped transforms). A key gets this marker in
# two cases:
#   1. the key is required by ``mapping`` but missing from the results;
#   2. the key is explicitly mapped to ``...`` (Ellipsis) in ``mapping``,
#      meaning the original value in the results should be ignored.
IgnoreKey = object()
# ``contextlib.nullcontext`` exists since Python 3.7. On older interpreters,
# fall back to an equivalent no-op context manager.
try:
    from contextlib import nullcontext  # type: ignore
except ImportError:
    from contextlib import contextmanager

    # The decorator is required: without it ``nullcontext(...)`` returns a
    # bare generator, which cannot be used in a ``with`` statement.
    @contextmanager  # type: ignore
    def nullcontext(resource=None):
        """Yield ``resource`` unchanged and do nothing on exit."""
        yield resource
@TRANSFORMS.register_module()
class Compose(BaseTransform):
    """Compose multiple transforms sequentially.

    Args:
        transforms (dict | callable | list[dict | callable]): A transform
            object or config dict, or a sequence of them, to be composed.
            Config dicts are built with ``TRANSFORMS.build``.

    Examples:
        >>> pipeline = [
        >>>     dict(type='Compose',
        >>>         transforms=[
        >>>             dict(type='LoadImageFromFile'),
        >>>             dict(type='Normalize')
        >>>         ]
        >>>     )
        >>> ]
    """

    def __init__(self, transforms: Union[Transform, Sequence[Transform]]):
        super().__init__()

        # A single transform (a dict or a callable) is treated as a
        # one-element sequence.
        if not isinstance(transforms, Sequence):
            transforms = [transforms]
        self.transforms: List = []
        for transform in transforms:
            if isinstance(transform, dict):
                transform = TRANSFORMS.build(transform)
                self.transforms.append(transform)
            elif callable(transform):
                self.transforms.append(transform)
            else:
                raise TypeError('transform must be callable or a dict, but got'
                                f' {type(transform)}')

    def __iter__(self):
        """Allow easy iteration over the transform sequence."""
        return iter(self.transforms)

    def transform(self, results: Dict) -> Optional[Dict]:
        """Call function to apply transforms sequentially.

        Args:
            results (dict): A result dict contains the results to transform.

        Returns:
            dict or None: Transformed results. If any transform returns None,
            the remaining transforms are skipped and None is returned.
        """
        for t in self.transforms:
            results = t(results)  # type: ignore
            if results is None:
                return None
        return results

    def __repr__(self):
        """Compute the string representation."""
        format_string = self.__class__.__name__ + '('
        for t in self.transforms:
            format_string += f'\n    {t}'
        format_string += '\n)'
        return format_string
@TRANSFORMS.register_module()
class KeyMapper(BaseTransform):
    """A transform wrapper to map and reorganize the input/output of the
    wrapped transforms (or sub-pipeline).

    Args:
        transforms (list[dict | callable], optional): Sequence of transform
            object or config dict to be wrapped.
        mapping (dict): A dict that defines the input key mapping.
            The keys correspond to the inner keys (i.e., kwargs of the
            ``transform`` method), and should be of string type. The values
            correspond to the outer keys (i.e., the keys of the
            data/results), and should have a type of string, list or dict.
            None means not applying input mapping. Default: None.
        remapping (dict): A dict that defines the output key mapping.
            The keys and values have the same meanings and rules as in the
            ``mapping``. Default: None.
        auto_remap (bool, optional): If True, an inverse of the mapping will
            be used as the remapping. If auto_remap is not given, it will be
            automatically set True if 'remapping' is not given, and vice
            versa. Default: None.
        allow_nonexist_keys (bool): Controls how outer keys referenced by
            ``mapping`` but missing from the input data are handled. If True,
            such keys are hidden from the wrapped transforms and skipped in
            remapping. If False, the missing value is passed to the wrapped
            transforms as ``None``. Default: False.

    Examples:
        >>> # Example 1: KeyMapper 'gt_img' to 'img'
        >>> pipeline = [
        >>>     # Use KeyMapper to convert outer (original) field name
        >>>     # 'gt_img' to inner (used by inner transforms) field name
        >>>     # 'img'
        >>>     dict(type='KeyMapper',
        >>>         mapping={'img': 'gt_img'},
        >>>         # auto_remap=True means output key mapping is the revert of
        >>>         # the input key mapping, e.g. inner 'img' will be mapped
        >>>         # back to outer 'gt_img'
        >>>         auto_remap=True,
        >>>         transforms=[
        >>>             # In all transforms' implementation just use 'img'
        >>>             # as a standard field name
        >>>             dict(type='Crop', crop_size=(384, 384)),
        >>>             dict(type='Normalize'),
        >>>         ])
        >>> ]

        >>> # Example 2: Collect and structure multiple items
        >>> pipeline = [
        >>>     # The inner field 'imgs' will be a dict with keys 'img_src'
        >>>     # and 'img_tar', whose values are outer fields 'img1' and
        >>>     # 'img2' respectively.
        >>>     dict(
        >>>         type='KeyMapper',
        >>>         mapping=dict(
        >>>             imgs=dict(
        >>>                 img_src='img1',
        >>>                 img_tar='img2')),
        >>>         transforms=...)
        >>> ]

        >>> # Example 3: Manually set ignored keys by "..."
        >>> pipeline = [
        >>>     ...
        >>>     dict(type='KeyMapper',
        >>>         mapping={
        >>>             # map outer key "gt_img" to inner key "img"
        >>>             'img': 'gt_img',
        >>>             # ignore outer key "mask"
        >>>             'mask': ...,
        >>>         },
        >>>         transforms=[
        >>>             dict(type='RandomFlip'),
        >>>         ])
        >>>     ...
        >>> ]
    """

    def __init__(self,
                 transforms: Optional[Union[Transform,
                                            List[Transform]]] = None,
                 mapping: Optional[Dict] = None,
                 remapping: Optional[Dict] = None,
                 auto_remap: Optional[bool] = None,
                 allow_nonexist_keys: bool = False):

        super().__init__()

        self.allow_nonexist_keys = allow_nonexist_keys
        self.mapping = mapping

        # Infer ``auto_remap`` from whether an explicit remapping was given.
        if auto_remap is None:
            auto_remap = remapping is None
        self.auto_remap = auto_remap

        if self.auto_remap:
            if remapping is not None:
                raise ValueError('KeyMapper: ``remapping`` must be None if '
                                 '``auto_remap`` is set True.')
            # The input mapping doubles as the output remapping: inner keys
            # are written back to the outer keys they were read from.
            self.remapping = mapping
        else:
            self.remapping = remapping

        if transforms is None:
            transforms = []
        self.transforms = Compose(transforms)

    def __iter__(self):
        """Allow easy iteration over the transform sequence."""
        return iter(self.transforms)

    def _map_input(self, data: Dict,
                   mapping: Optional[Dict]) -> Dict[str, Any]:
        """KeyMapper inputs for the wrapped transforms by gathering and
        renaming data items according to the mapping.

        Args:
            data (dict): The original input data.
            mapping (dict, optional): The input key mapping. See the document
                of ``mmcv.transforms.wrappers.KeyMapper`` for details. If
                set to None, a shallow copy of the input data is returned.

        Returns:
            dict: The input data with remapped keys. This will be the actual
            input of the wrapped pipeline. Unmapped items are retained.
            Outer keys mapped to ``...`` become ``IgnoreKey``; missing outer
            keys become ``IgnoreKey`` if ``allow_nonexist_keys`` is True and
            ``None`` otherwise.
        """

        if mapping is None:
            return data.copy()

        def _map(data, m):
            if isinstance(m, dict):
                # m is a dict {inner_key:outer_key, ...}
                return {k_in: _map(data, k_out) for k_in, k_out in m.items()}
            if isinstance(m, (tuple, list)):
                # m is a list or tuple [outer_key1, outer_key2, ...]
                # This is the case when we collect items from the original
                # data to form a list or tuple to feed to the wrapped
                # transforms.
                return m.__class__(_map(data, e) for e in m)

            # allow manually mark a key to be ignored by ...
            if m is ...:
                return IgnoreKey

            # m is an outer_key
            if self.allow_nonexist_keys:
                return data.get(m, IgnoreKey)
            else:
                return data.get(m)

        collected = _map(data, mapping)

        # Retain unmapped items
        inputs = data.copy()
        inputs.update(collected)

        return inputs

    def _map_output(self, data: Dict,
                    remapping: Optional[Dict]) -> Dict[str, Any]:
        """KeyMapper outputs from the wrapped transforms by gathering and
        renaming data items according to the remapping.

        Args:
            data (dict): The output of the wrapped pipeline.
            remapping (dict, optional): The output key mapping. See the
                document of ``mmcv.transforms.wrappers.KeyMapper`` for
                details. If ``remapping is None``, no key mapping will be
                applied but only remove the special token ``IgnoreKey``.

        Returns:
            dict: The output with remapped keys.
        """

        # Remove ``IgnoreKey``
        if remapping is None:
            return {k: v for k, v in data.items() if v is not IgnoreKey}

        def _map(data, m):
            if isinstance(m, dict):
                assert isinstance(data, dict)
                results = {}
                for k_in, k_out in m.items():
                    assert k_in in data
                    results.update(_map(data[k_in], k_out))
                return results
            if isinstance(m, (list, tuple)):
                assert isinstance(data, (list, tuple))
                assert len(data) == len(m)
                results = {}
                for m_i, d_i in zip(m, data):
                    results.update(_map(d_i, m_i))
                return results

            # ``m is ...`` means the key is marked ignored, in which case the
            # inner results will not affect the outer results in remapping.
            # Another case that will have ``data is IgnoreKey`` is that the
            # key is missing in the inputs. In this case, if the inner key is
            # created by the wrapped transforms, it will be remapped to the
            # corresponding outer key during remapping.
            if m is ... or data is IgnoreKey:
                return {}

            return {m: data}

        # Note that unmapped items are not retained, which is different from
        # the behavior in _map_input. This is to avoid original data items
        # being overwritten by intermediate namesakes
        return _map(data, remapping)

    def _apply_transforms(self, inputs: Dict) -> Dict:
        """Apply ``self.transforms``.

        Note that the special token ``IgnoreKey`` will be invisible to
        ``self.transforms``, but not removed in this method. It will be
        eventually removed in :meth:`_map_output`.

        Raises:
            ValueError: If the wrapped transforms return None.
        """
        results = inputs.copy()
        inputs = {k: v for k, v in inputs.items() if v is not IgnoreKey}
        outputs = self.transforms(inputs)

        if outputs is None:
            raise ValueError(
                f'Transforms wrapped by {self.__class__.__name__} should '
                'not return None.')

        results.update(outputs)  # type: ignore
        return results

    def transform(self, results: Dict) -> Dict:
        """Apply mapping, wrapped transforms and remapping."""

        # Apply mapping
        inputs = self._map_input(results, self.mapping)
        # Apply wrapped transforms
        outputs = self._apply_transforms(inputs)
        # Apply remapping
        outputs = self._map_output(outputs, self.remapping)

        results.update(outputs)  # type: ignore
        return results

    def __repr__(self) -> str:
        repr_str = self.__class__.__name__
        repr_str += f'(transforms = {self.transforms}'
        repr_str += f', mapping = {self.mapping}'
        repr_str += f', remapping = {self.remapping}'
        repr_str += f', auto_remap = {self.auto_remap}'
        repr_str += f', allow_nonexist_keys = {self.allow_nonexist_keys})'
        return repr_str
@TRANSFORMS.register_module()
class TransformBroadcaster(KeyMapper):
    """A transform wrapper to apply the wrapped transforms to multiple data
    items. For example, apply Resize to multiple images.

    Args:
        transforms (list[dict | callable]): Sequence of transform object or
            config dict to be wrapped.
        mapping (dict): A dict that defines the input key mapping.
            Note that to apply the transforms to multiple data items, the
            outer keys of the target items should be remapped as a list with
            the standard inner key (The key required by the wrapped transform).
            See the following example and the document of
            ``mmcv.transforms.wrappers.KeyMapper`` for details.
        remapping (dict): A dict that defines the output key mapping.
            The keys and values have the same meanings and rules as in the
            ``mapping``. Default: None.
        auto_remap (bool, optional): If True, an inverse of the mapping will
            be used as the remapping. If auto_remap is not given, it will be
            automatically set True if 'remapping' is not given, and vice
            versa. Default: None.
        allow_nonexist_keys (bool): Controls how outer keys referenced by
            ``mapping`` but missing from the input data are handled. If True,
            such keys are hidden from the wrapped transforms and skipped in
            remapping. If False, the missing value is passed to the wrapped
            transforms as ``None``. Default: False.
        share_random_params (bool): If True, the random transform
            (e.g., RandomFlip) will be conducted in a deterministic way and
            have the same behavior on all data items. For example, to randomly
            flip either both input image and ground-truth image, or none.
            Default: False.

    .. note::
        To apply the transforms to each element of a list or tuple, instead
        of separating data items, you can map the outer key of the target
        sequence to the standard inner key. See example 2.

    Examples:
        >>> # Example 1: Broadcast to enumerated keys, each contains a single
        >>> # data element
        >>> pipeline = [
        >>>     dict(type='LoadImageFromFile', key='lq'),  # low-quality img
        >>>     dict(type='LoadImageFromFile', key='gt'),  # ground-truth img
        >>>     # TransformBroadcaster maps multiple outer fields to the
        >>>     # standard inner field and processes them with the wrapped
        >>>     # transforms respectively
        >>>     dict(type='TransformBroadcaster',
        >>>         # case 1: from multiple outer fields
        >>>         mapping={'img': ['lq', 'gt']},
        >>>         auto_remap=True,
        >>>         # share_random_params=True means using identical random
        >>>         # parameters in every processing
        >>>         share_random_params=True,
        >>>         transforms=[
        >>>             dict(type='Crop', crop_size=(384, 384)),
        >>>             dict(type='Normalize'),
        >>>         ])
        >>> ]

        >>> # Example 2: Broadcast to keys that contains data sequences
        >>> pipeline = [
        >>>     # Assume the outer field 'images' holds a list of images
        >>>     # prepared by previous transforms.
        >>>     dict(type='TransformBroadcaster',
        >>>         # case 2: from one outer field that contains multiple
        >>>         # data elements (e.g. a list)
        >>>         mapping={'img': 'images'},
        >>>         auto_remap=True,
        >>>         share_random_params=True,
        >>>         transforms=[
        >>>             dict(type='Crop', crop_size=(384, 384)),
        >>>             dict(type='Normalize'),
        >>>         ])
        >>> ]

        >>> # Example 3: Set ignored keys in broadcasting
        >>> pipeline = [
        >>>     dict(type='TransformBroadcaster',
        >>>         # Broadcast the wrapped transforms to multiple images
        >>>         # 'lq' and 'gt', but only update 'img_shape' once
        >>>         mapping={
        >>>             'img': ['lq', 'gt'],
        >>>             'img_shape': ['img_shape', ...],
        >>>         },
        >>>         auto_remap=True,
        >>>         share_random_params=True,
        >>>         transforms=[
        >>>             # `RandomCrop` will modify the field "img",
        >>>             # and optionally update "img_shape" if it exists
        >>>             dict(type='RandomCrop'),
        >>>         ])
        >>> ]
    """

    def __init__(self,
                 transforms: List[Union[Dict, Callable[[Dict], Dict]]],
                 mapping: Optional[Dict] = None,
                 remapping: Optional[Dict] = None,
                 auto_remap: Optional[bool] = None,
                 allow_nonexist_keys: bool = False,
                 share_random_params: bool = False):
        super().__init__(transforms, mapping, remapping, auto_remap,
                         allow_nonexist_keys)

        self.share_random_params = share_random_params

    def scatter_sequence(self, data: Dict) -> List[Dict]:
        """Scatter the broadcasting targets to a list of inputs of the wrapped
        transforms.

        The targets are the keys of ``self.mapping`` (or every key of
        ``data`` when ``self.mapping`` is None or empty). Each target must
        hold a sequence, and all targets must have the same length ``N``.
        The i-th returned dict is a shallow copy of ``data`` whose targets
        hold their i-th element.

        Args:
            data (dict): The input data after applying the input mapping.

        Returns:
            list[dict]: ``N`` input dicts for the wrapped transforms.

        Raises:
            ValueError: If the targets have inconsistent sequence lengths.
        """
        # infer split number from input
        seq_len = 0
        key_rep = None

        if self.mapping:
            keys = self.mapping.keys()
        else:
            keys = data.keys()

        for key in keys:
            assert isinstance(data[key], Sequence), \
                f'Broadcasting target "{key}" must be a sequence, but got ' \
                f'{type(data[key])}.'
            if seq_len:
                if len(data[key]) != seq_len:
                    raise ValueError('Got inconsistent sequence length: '
                                     f'{seq_len} ({key_rep}) vs. '
                                     f'{len(data[key])} ({key})')
            else:
                seq_len = len(data[key])
                key_rep = key

        assert seq_len > 0, 'Fail to get the number of broadcasting targets'

        scatters = []
        for i in range(seq_len):  # type: ignore
            scatter = data.copy()
            for key in keys:
                scatter[key] = data[key][i]
            scatters.append(scatter)
        return scatters

    def transform(self, results: Dict) -> Dict:
        """Broadcast wrapped transforms to multiple targets."""

        # Apply input remapping
        inputs = self._map_input(results, self.mapping)

        # Scatter sequential inputs into a list
        input_scatters = self.scatter_sequence(inputs)

        # Control random parameter sharing with a context manager
        if self.share_random_params:
            # The context manager :func:`cache_random_params` will let
            # cacheable methods of the transforms cache their outputs. Thus
            # the random parameters will only be generated once and shared
            # by all data items.
            ctx = cache_random_params  # type: ignore
        else:
            ctx = nullcontext  # type: ignore

        with ctx(self.transforms):
            output_scatters = [
                self._apply_transforms(_input) for _input in input_scatters
            ]

        # Collate output scatters (list of dict to dict of list)
        outputs = {
            key: [_output[key] for _output in output_scatters]
            for key in output_scatters[0]
        }

        # Apply remapping
        outputs = self._map_output(outputs, self.remapping)

        results.update(outputs)
        return results

    def __repr__(self) -> str:
        repr_str = self.__class__.__name__
        repr_str += f'(transforms = {self.transforms}'
        repr_str += f', mapping = {self.mapping}'
        repr_str += f', remapping = {self.remapping}'
        repr_str += f', auto_remap = {self.auto_remap}'
        repr_str += f', allow_nonexist_keys = {self.allow_nonexist_keys}'
        repr_str += f', share_random_params = {self.share_random_params})'
        return repr_str
@TRANSFORMS.register_module()
class RandomChoice(BaseTransform):
    """Process data with a randomly chosen transform from given candidates.

    Args:
        transforms (list[list]): A list of transform candidates, each is a
            sequence of transforms.
        prob (list[float], optional): The probabilities associated
            with each pipeline. The length should be equal to the pipeline
            number and the sum should be 1. If not given, a uniform
            distribution will be assumed.

    Examples:
        >>> # config
        >>> pipeline = [
        >>>     dict(type='RandomChoice',
        >>>         transforms=[
        >>>             [dict(type='RandomHorizontalFlip')],  # subpipeline 1
        >>>             [dict(type='RandomRotate')],  # subpipeline 2
        >>>         ]
        >>>     )
        >>> ]
    """

    def __init__(self,
                 transforms: List[Union[Transform, List[Transform]]],
                 prob: Optional[List[float]] = None):

        super().__init__()

        if prob is not None:
            assert mmengine.is_seq_of(prob, float)
            assert len(transforms) == len(prob), \
                '``transforms`` and ``prob`` must have same lengths. ' \
                f'Got {len(transforms)} vs {len(prob)}.'
            # An exact ``== 1`` check rejects valid inputs such as
            # ``[0.1] * 10`` (float sum 0.9999999999999999). The tolerance
            # is kept at the scale ``np.random.choice`` itself accepts for
            # ``p`` (about 1.5e-8), so anything passing here also passes
            # there.
            assert np.isclose(sum(prob), 1.0, rtol=0, atol=1e-8), \
                f'The sum of ``prob`` must be 1, but got {sum(prob)}.'

        self.prob = prob
        # Each candidate becomes its own sub-pipeline.
        self.transforms = [Compose(candidate) for candidate in transforms]

    def __iter__(self):
        return iter(self.transforms)

    @cache_randomness
    def random_pipeline_index(self) -> int:
        """Return a random transform index.

        Decorated with ``cache_randomness`` so that, inside
        ``cache_random_params`` (e.g. ``TransformBroadcaster`` with
        ``share_random_params=True``), every data item uses the same
        candidate.
        """
        indices = np.arange(len(self.transforms))
        return np.random.choice(indices, p=self.prob)

    def transform(self, results: Dict) -> Optional[Dict]:
        """Randomly choose a transform to apply."""
        idx = self.random_pipeline_index()
        return self.transforms[idx](results)

    def __repr__(self) -> str:
        repr_str = self.__class__.__name__
        repr_str += f'(transforms = {self.transforms}'
        repr_str += f', prob = {self.prob})'
        return repr_str
@TRANSFORMS.register_module()
class RandomApply(BaseTransform):
    """Apply transforms randomly with a given probability.

    Args:
        transforms (list[dict | callable]): The transform or transform list
            to randomly apply.
        prob (float): The probability to apply transforms. Default: 0.5

    Examples:
        >>> # config
        >>> pipeline = [
        >>>     dict(type='RandomApply',
        >>>         transforms=[dict(type='HorizontalFlip')],
        >>>         prob=0.3)
        >>> ]
    """

    def __init__(self,
                 transforms: Union[Transform, List[Transform]],
                 prob: float = 0.5):

        super().__init__()
        self.prob = prob
        self.transforms = Compose(transforms)

    def __iter__(self):
        return iter(self.transforms)

    @cache_randomness
    def random_apply(self) -> bool:
        """Return a random bool value indicating whether apply the
        transform.

        Decorated with ``cache_randomness`` so that, inside
        ``cache_random_params`` (e.g. ``TransformBroadcaster`` with
        ``share_random_params=True``), the decision is shared by every data
        item.
        """
        return np.random.rand() < self.prob

    def transform(self, results: Dict) -> Optional[Dict]:
        """Randomly apply the transform."""
        if self.random_apply():
            return self.transforms(results)  # type: ignore
        else:
            return results

    def __repr__(self) -> str:
        repr_str = self.__class__.__name__
        repr_str += f'(transforms = {self.transforms}'
        repr_str += f', prob = {self.prob})'
        return repr_str