# Spaces:
# Runtime error
# Runtime error
from typing import Callable, Dict, List, Optional, Sequence, Union | |
import cv2 | |
import numpy as np | |
from mmcv.transforms import TRANSFORMS | |
from mmcv.transforms.utils import cache_random_params | |
from mmcv.transforms.wrappers import * | |
# Type accepted wherever a transform is expected: either a config dict that
# describes the transform, or a callable taking a results dict and returning
# a results dict.
Transform = Union[
    Dict,  # transform config dict
    Callable[[Dict], Dict],  # transform object / function: dict -> dict
]
# Indicator of keys marked by KeyMapper._map_input, which means ignoring the | |
# marked keys in KeyMapper._apply_transform so they will be invisible to | |
# wrapped transforms. | |
# This can be 2 possible case: | |
# 1. The key is required but missing in results | |
# 2. The key is manually set as ... (Ellipsis) in ``mapping``, which means | |
# the original value in results should be ignored | |
IgnoreKey = object() | |
# ``contextlib.nullcontext`` is available on Python >= 3.7. On older
# interpreters, fall back to an equivalent no-op context manager that simply
# yields ``resource``. The ``@contextmanager`` decorator is required: without
# it the fallback is a bare generator function and cannot be used in ``with``.
try:
    from contextlib import nullcontext  # type: ignore
except ImportError:
    from contextlib import contextmanager

    @contextmanager  # type: ignore
    def nullcontext(resource=None):
        """No-op context manager that yields ``resource`` unchanged."""
        yield resource
def imdenormalize(img, mean, std, to_bgr=True):
    """Undo mean/std normalization of an image, i.e. compute ``img * std + mean``.

    Args:
        img (ndarray): Normalized image. Must not be ``uint8``. The input
            array is not modified; the result is written to a new array.
        mean (array-like): Per-channel mean that was subtracted. Any
            array-like (ndarray, list, tuple) is accepted.
        std (array-like): Per-channel std the image was divided by. Any
            array-like is accepted.
        to_bgr (bool): Whether to convert the result from RGB to BGR.
            Default: True.

    Returns:
        ndarray: The denormalized image.
    """
    assert img.dtype != np.uint8, (
        "imdenormalize expects a normalized (non-uint8) image"
    )
    # np.array always copies, giving a contiguous float64 (1, C) row just like
    # the former ``.reshape(1, -1).astype(np.float64)``, but it also accepts
    # plain Python sequences instead of requiring an ndarray.
    mean = np.array(mean, dtype=np.float64).reshape(1, -1)
    std = np.array(std, dtype=np.float64).reshape(1, -1)
    img = cv2.multiply(img, std)  # returns a new array; the input is untouched
    cv2.add(img, mean, img)  # in place on that new array
    if to_bgr:
        cv2.cvtColor(img, cv2.COLOR_RGB2BGR, img)  # in place
    return img
class MasaTransformBroadcaster(KeyMapper):
    """Run a pipeline of wrapped transforms once per item of a data sequence.

    After input remapping, every broadcast target holds a sequence (e.g. a
    list of images). The wrapped transforms are applied to each position of
    that sequence separately, for example to ``Resize`` every image of a
    multi-image sample. The per-item outputs are then gathered back into
    lists before output remapping.

    Args:
        transforms (list[dict | callable]): Transform objects or config
            dicts to wrap.
        mapping (dict, optional): Input key mapping. To broadcast over
            several items, map the standard inner key (the one the wrapped
            transform reads) to a list of outer keys. See
            ``mmcv.transforms.wrappers.KeyMapper`` for the full mapping rules.
            Default: None.
        remapping (dict, optional): Output key mapping, with the same
            meaning and rules as ``mapping``. Default: None.
        auto_remap (bool, optional): If True, the inverse of ``mapping`` is
            used as the remapping. When not given, it becomes True if
            ``remapping`` is not given, and False otherwise. Default: None.
        allow_nonexist_keys (bool): If False, every outer key referenced by
            ``mapping`` must exist in the input data, otherwise an exception
            is raised. Default: False.
        share_random_params (bool): If True, random transforms (e.g.
            ``RandomFlip``) draw their random parameters once and reuse them
            for every item. For example, an input image and its ground-truth
            image are then either both flipped or both left as they are.
            Default: False.
    """

    def __init__(
        self,
        transforms: List[Union[Dict, Callable[[Dict], Dict]]],
        mapping: Optional[Dict] = None,
        remapping: Optional[Dict] = None,
        auto_remap: Optional[bool] = None,
        allow_nonexist_keys: bool = False,
        share_random_params: bool = False,
    ) -> None:
        super().__init__(
            transforms,
            mapping,
            remapping,
            auto_remap,
            allow_nonexist_keys,
        )
        self.share_random_params = share_random_params

    def scatter_sequence(self, data: Dict) -> List[Dict]:
        """Split the broadcast targets of ``data`` into one input per item.

        Args:
            data (dict): Remapped inputs. Every broadcast key (the keys of
                ``self.mapping`` when it is set, otherwise all keys of
                ``data``) must hold a ``Sequence``.

        Returns:
            list[dict]: One shallow copy of ``data`` per position. In each
            copy, every broadcast key holds the element at that position.

        Raises:
            AssertionError: If a broadcast value is not a ``Sequence``, or if
                no non-zero sequence length could be determined.
            ValueError: If two broadcast values have different lengths.
        """
        broadcast_keys = self.mapping.keys() if self.mapping else data.keys()

        # The first non-empty sequence fixes the expected length (and the key
        # quoted in error messages); every later sequence is checked
        # against it.
        expected_len, reference_key = 0, None
        for name in broadcast_keys:
            value = data[name]
            assert isinstance(value, Sequence)
            if not expected_len:
                expected_len, reference_key = len(value), name
                continue
            if len(value) != expected_len:
                raise ValueError(
                    "Got inconsistent sequence length: "
                    f"{expected_len} ({reference_key}) vs. "
                    f"{len(value)} ({name})"
                )

        assert expected_len > 0, "Fail to get the number of broadcasting targets"

        scattered = []
        for idx in range(expected_len):
            item = data.copy()
            item.update((name, data[name][idx]) for name in broadcast_keys)
            scattered.append(item)
        return scattered

    def transform(self, results: Dict):
        """Apply the wrapped transforms to every broadcast target.

        Args:
            results (dict): The data dict. It is updated in place with the
                remapped, collated outputs.

        Returns:
            dict: ``results`` after the update.
        """
        input_scatters = self.scatter_sequence(
            self._map_input(results, self.mapping)
        )

        # With share_random_params, the cache_random_params context manager
        # makes the cacheable methods of the transforms cache their outputs.
        # The random parameters are therefore generated only once and shared
        # by all items. Otherwise, a no-op context is used.
        ctx = (  # type: ignore
            cache_random_params if self.share_random_params else nullcontext
        )
        with ctx(self.transforms):
            output_scatters = list(map(self._apply_transforms, input_scatters))

        # Collate the per-item outputs: a list of dicts becomes a dict of
        # lists, keyed like the first output.
        collated = {}
        for key in output_scatters[0]:
            collated[key] = [out[key] for out in output_scatters]

        results.update(self._map_output(collated, self.remapping))
        return results

    def __repr__(self) -> str:
        fields = (
            ("transforms", self.transforms),
            ("mapping", self.mapping),
            ("remapping", self.remapping),
            ("auto_remap", self.auto_remap),
            ("allow_nonexist_keys", self.allow_nonexist_keys),
            ("share_random_params", self.share_random_params),
        )
        body = ", ".join(f"{name} = {value}" for name, value in fields)
        return f"{self.__class__.__name__}({body})"