JohanDL's picture
initial commit
f1dd031
raw
history blame
6.44 kB
from typing import Callable, Dict, List, Optional, Sequence, Union
import cv2
import numpy as np
from mmcv.transforms import TRANSFORMS
from mmcv.transforms.utils import cache_random_params
from mmcv.transforms.wrappers import *
# Type of a transform: either a config dict or a callable that takes a
# results dict and returns a results dict.
Transform = Union[Dict, Callable[[Dict], Dict]]
# Sentinel for keys marked by KeyMapper._map_input. Marked keys are ignored
# in KeyMapper._apply_transform, so they are invisible to the wrapped
# transforms.
# A key is marked in either of 2 cases:
# 1. The key is required but missing in results.
# 2. The key is manually set to ... (Ellipsis) in ``mapping``, which means
#    the original value in results should be ignored.
# NOTE(review): this creates a fresh sentinel object; it is not referenced
# anywhere else in this file and is not the same object as any ``IgnoreKey``
# that mmcv's KeyMapper may use internally -- confirm whether it is needed.
IgnoreKey = object()
# Prefer the standard-library ``nullcontext`` (available on python>=3.7);
# when it cannot be imported, fall back to a minimal stand-in.
try:
    from contextlib import nullcontext  # type: ignore
except ImportError:
    from contextlib import contextmanager

    @contextmanager  # type: ignore
    def nullcontext(resource=None):
        """Do-nothing context manager that yields ``resource`` unchanged."""
        # Nothing is acquired on entry, so there is nothing to release on
        # exit; exceptions raised inside the ``with`` body simply propagate.
        yield resource
def imdenormalize(img, mean, std, to_bgr=True):
    """Undo a mean/std normalization: compute ``img * std + mean``.

    Args:
        img (np.ndarray): Normalized image. Must not have dtype ``uint8``.
            (presumably an H x W x C float array -- confirm against callers)
        mean (np.ndarray): Mean values; flattened to a single row and cast
            to float64 before use.
        std (np.ndarray): Std values; flattened to a single row and cast to
            float64 before use.
        to_bgr (bool): If True, run an RGB -> BGR channel conversion on the
            result. Default: True.

    Returns:
        np.ndarray: The denormalized image. The input ``img`` is not
        modified, because the multiplication allocates a new array.
    """
    assert img.dtype != np.uint8
    mean_row = mean.reshape(1, -1).astype(np.float64)
    std_row = std.reshape(1, -1).astype(np.float64)
    # The multiplication produces a fresh array; every later step writes
    # into that array in place.
    restored = cv2.multiply(img, std_row)
    cv2.add(restored, mean_row, restored)
    if to_bgr:
        cv2.cvtColor(restored, cv2.COLOR_RGB2BGR, restored)
    return restored
@TRANSFORMS.register_module()
class MasaTransformBroadcaster(KeyMapper):
    """Wrapper that runs the wrapped transforms once per data item.

    The input is split into one dict per item, the wrapped transforms are
    applied to each of them, and the per-item outputs are gathered back
    into lists. A typical use is applying Resize to several images.

    Args:
        transforms (list[dict | callable]): Sequence of transform objects or
            config dicts to be wrapped.
        mapping (dict): A dict that defines the input key mapping. To
            broadcast over several data items, the outer keys of the target
            items should be remapped as a list under the standard inner key
            (the key required by the wrapped transform). See the
            documentation of ``mmcv.transforms.wrappers.KeyMapper`` for
            details.
        remapping (dict): A dict that defines the output key mapping, with
            the same meanings and rules as ``mapping``. Default: None.
        auto_remap (bool, optional): If True, an inverse of the mapping is
            used as the remapping. If not given, it is automatically set to
            True when ``remapping`` is not given, and vice versa.
            Default: None.
        allow_nonexist_keys (bool): If False, the outer keys in the mapping
            must exist in the input data, or an exception will be raised.
            Default: False.
        share_random_params (bool): If True, random transforms (e.g.,
            RandomFlip) generate their random parameters once and reuse
            them for every data item, e.g. to flip either both the input
            image and the ground-truth image, or neither. Default: False.
    """

    def __init__(
        self,
        transforms: List[Union[Dict, Callable[[Dict], Dict]]],
        mapping: Optional[Dict] = None,
        remapping: Optional[Dict] = None,
        auto_remap: Optional[bool] = None,
        allow_nonexist_keys: bool = False,
        share_random_params: bool = False,
    ):
        super().__init__(
            transforms, mapping, remapping, auto_remap, allow_nonexist_keys
        )
        self.share_random_params = share_random_params

    def scatter_sequence(self, data: Dict) -> List[Dict]:
        """Split ``data`` into one input dict per broadcasting target.

        The broadcast keys are the keys of ``self.mapping`` when it is
        non-empty, otherwise every key of ``data``. Each broadcast key must
        hold a sequence, and all of them must have the same length.

        Args:
            data (dict): The (already input-mapped) results dict.

        Returns:
            list[dict]: One shallow copy of ``data`` per target, in which
            each broadcast key holds that target's element.

        Raises:
            ValueError: If two broadcast keys have different lengths.
            AssertionError: If a broadcast value is not a sequence, or no
                non-zero number of targets could be determined.
        """
        target_keys = self.mapping.keys() if self.mapping else data.keys()

        # Infer the number of targets from the first non-empty sequence and
        # check every following one against it.
        num_targets = 0
        reference_key = None
        for name in target_keys:
            assert isinstance(data[name], Sequence)
            if not num_targets:
                num_targets = len(data[name])
                reference_key = name
            elif len(data[name]) != num_targets:
                raise ValueError(
                    "Got inconsistent sequence length: "
                    f"{num_targets} ({reference_key}) vs. "
                    f"{len(data[name])} ({name})"
                )
        assert num_targets > 0, "Fail to get the number of broadcasting targets"

        per_target = []
        for idx in range(num_targets):
            item = data.copy()
            item.update({name: data[name][idx] for name in target_keys})
            per_target.append(item)
        return per_target

    def transform(self, results: Dict):
        """Apply the wrapped transforms to every target and merge the result.

        Args:
            results (dict): The results dict; it is updated in place.

        Returns:
            dict: ``results``, updated with the remapped, collated outputs.
        """
        # Input remapping, then one input dict per target.
        mapped_inputs = self._map_input(results, self.mapping)
        scattered_inputs = self.scatter_sequence(mapped_inputs)

        # With parameter sharing on, :func:`cache_random_params` lets the
        # cacheable methods of the transforms cache their outputs, so the
        # random parameters are generated only once and shared by all data
        # items. Otherwise a no-op context is used.
        ctx = (  # type: ignore
            cache_random_params if self.share_random_params else nullcontext
        )
        with ctx(self.transforms):
            scattered_outputs = [
                self._apply_transforms(item) for item in scattered_inputs
            ]

        # Collate: list of dicts -> dict of lists, following the keys of the
        # first output.
        collated = {}
        for key in scattered_outputs[0]:
            collated[key] = [single[key] for single in scattered_outputs]

        # Output remapping, then merge back into the caller's dict.
        results.update(self._map_output(collated, self.remapping))
        return results

    def __repr__(self) -> str:
        """Return the class name followed by its configuration values."""
        parts = (
            f"transforms = {self.transforms}",
            f"mapping = {self.mapping}",
            f"remapping = {self.remapping}",
            f"auto_remap = {self.auto_remap}",
            f"allow_nonexist_keys = {self.allow_nonexist_keys}",
            f"share_random_params = {self.share_random_params}",
        )
        return f"{self.__class__.__name__}({', '.join(parts)})"