Spaces:
Running
on
Zero
Running
on
Zero
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
# All rights reserved. | |
# | |
# This source code is licensed under the license found in the | |
# LICENSE file in the root directory of this source tree. | |
from typing import Sequence, Union | |
import mmengine | |
import numpy as np | |
import torch | |
from .base import BaseTransform | |
from .builder import TRANSFORMS | |
def to_tensor(
    data: Union[torch.Tensor, np.ndarray, Sequence, int,
                float]) -> torch.Tensor:
    """Convert objects of various python types to :obj:`torch.Tensor`.

    Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`,
    :class:`Sequence`, :class:`int` and :class:`float`.

    Conversion rules, checked in this order:

    - :class:`torch.Tensor`: returned unchanged (same object).
    - :class:`numpy.ndarray`: wrapped with :func:`torch.from_numpy`. Arrays
      that contain a negative stride (e.g. flipped views) are copied first,
      because :func:`torch.from_numpy` rejects them.
    - non-string :class:`Sequence`: passed to :func:`torch.tensor`.
    - :class:`int`: a one-element ``torch.LongTensor``.
    - :class:`float`: a one-element ``torch.FloatTensor``.

    Args:
        data (torch.Tensor | numpy.ndarray | Sequence | int | float): Data to
            be converted.

    Returns:
        torch.Tensor: the converted data.

    Raises:
        TypeError: If ``data`` is none of the supported types.
    """
    if isinstance(data, torch.Tensor):
        return data
    elif isinstance(data, np.ndarray):
        # Views such as ``arr[:, ::-1]`` or ``np.flip(arr)`` have negative
        # strides, which ``torch.from_numpy`` refuses with a ValueError.
        # Copy only in that case so every other array keeps the zero-copy
        # path it had before.
        if any(stride < 0 for stride in data.strides):
            data = data.copy()
        return torch.from_numpy(data)
    elif isinstance(data, Sequence) and not mmengine.is_str(data):
        return torch.tensor(data)
    elif isinstance(data, int):
        return torch.LongTensor([data])
    elif isinstance(data, float):
        return torch.FloatTensor([data])
    else:
        raise TypeError(f'type {type(data)} cannot be converted to tensor.')
class ToTensor(BaseTransform):
    """Convert the entries selected by ``keys`` to :obj:`torch.Tensor`.

    A key may address a nested entry by joining the path segments with
    ``'.'``; e.g. ``'a.b'`` refers to ``results['a']['b']``.

    Required keys:

    - all these keys in `keys`

    Modified Keys:

    - all these keys in `keys`

    Args:
        keys (Sequence[str]): Keys that need to be converted to Tensor.
    """

    def __init__(self, keys: Sequence[str]) -> None:
        self.keys = keys

    def transform(self, results: dict) -> dict:
        """Replace every configured entry of ``results`` with a tensor.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: The same dict, with the entries named by `keys` converted
            in place.

        Raises:
            KeyError: If any segment of a key is missing from the container
                reached so far.
        """
        for key in self.keys:
            # Everything before the last dot is a path to walk; the last
            # segment names the entry to overwrite.
            *parents, leaf = key.split('.')
            node = results
            for segment in parents:
                if segment not in node:
                    raise KeyError(f'Can not find key {key}')
                node = node[segment]
            if leaf not in node:
                raise KeyError(f'Can not find key {key}')
            node[leaf] = to_tensor(node[leaf])
        return results

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(keys={self.keys})'
class ImageToTensor(BaseTransform):
    """Convert image to :obj:`torch.Tensor` by given keys.

    The dimension order of input image is (H, W, C). The pipeline will convert
    it to (C, H, W). If only 2 dimension (H, W) is given, the output would be
    (1, H, W).

    Required keys:

    - all these keys in `keys`

    Modified Keys:

    - all these keys in `keys`

    Args:
        keys (Sequence[str]): Key of images to be converted to Tensor.
    """

    def __init__(self, keys: Sequence[str]) -> None:
        self.keys = keys

    def transform(self, results: dict) -> dict:
        """Transform function to convert image in results to
        :obj:`torch.Tensor` and transpose the channel order.

        Args:
            results (dict): Result dict contains the image data to convert.

        Returns:
            dict: The result dict contains the image converted
            to :obj:``torch.Tensor`` and transposed to (C, H, W) order.
        """
        for key in self.keys:
            img = results[key]
            if len(img.shape) < 3:
                # Give (H, W) images a trailing channel axis so the
                # transpose below yields (1, H, W).
                img = np.expand_dims(img, -1)
            chw = img.transpose(2, 0, 1)
            # Flipped images (e.g. ``img[:, ::-1]`` / ``np.flip``) are views
            # with negative strides, which ``torch.from_numpy`` rejects.
            # Copy only in that case; all other inputs take the original
            # path unchanged.
            if isinstance(chw, np.ndarray) and any(
                    stride < 0 for stride in chw.strides):
                chw = chw.copy()
            results[key] = to_tensor(chw).contiguous()
        return results

    def __repr__(self) -> str:
        return self.__class__.__name__ + f'(keys={self.keys})'