# rawalkhirodkar's picture
# Add initial commit
# 28c256d
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import warnings
from typing import List, Sequence, Union
import numpy as np
import torch
from .base_data_element import BaseDataElement
class PixelData(BaseDataElement):
    """Data structure for pixel-level annotations or predictions.

    All data items in ``data_fields`` of ``PixelData`` meet the following
    requirements:

    - They all have 3 dimensions in orders of channel, height, and width.
    - They should have the same height and width.

    A 2-D value assigned to a data field is expanded to shape ``(1, H, W)``.

    Examples:
        >>> import random
        >>> import numpy as np
        >>> import torch
        >>> metainfo = dict(
        ...     img_id=random.randint(0, 100),
        ...     img_shape=(random.randint(400, 600), random.randint(400, 600)))
        >>> image = np.random.randint(0, 255, (4, 20, 40))
        >>> featmap = torch.randint(0, 255, (10, 20, 40))
        >>> pixel_data = PixelData(metainfo=metainfo,
        ...                        image=image,
        ...                        featmap=featmap)
        >>> print(pixel_data.shape)
        (20, 40)
        >>> # slice
        >>> slice_data = pixel_data[10:20, 20:40]
        >>> assert slice_data.shape == (10, 20)
        >>> slice_data = pixel_data[10, 20]
        >>> assert slice_data.shape == (1, 1)
        >>> # set
        >>> pixel_data.map3 = torch.randint(0, 255, (20, 40))
        >>> assert tuple(pixel_data.map3.shape) == (1, 20, 40)
        >>> try:
        ...     # The dimension must be 3 or 2
        ...     pixel_data.map2 = torch.randint(0, 255, (1, 3, 20, 40))
        ... except AssertionError:
        ...     print('AssertionError')
        AssertionError
    """

    def __setattr__(self, name: str, value: Union[torch.Tensor, np.ndarray]):
        """Set attributes of ``PixelData``.

        If the dimension of value is 2 and its shape meets the requirements,
        a leading channel dimension is added automatically.

        Args:
            name (str): The key to access the value, stored in `PixelData`.
            value (Union[torch.Tensor, np.ndarray]): The value to store in.
                The type of value must be `torch.Tensor` or `np.ndarray`,
                and its shape must meet the requirements of `PixelData`.

        Raises:
            AttributeError: If ``name`` is ``'_metainfo_fields'`` or
                ``'_data_fields'`` and that attribute already exists.
            AssertionError: If ``value`` is not a tensor/array, its
                dimension is not 2 or 3, or its height and width differ from
                the data already stored.
        """
        if name in ('_metainfo_fields', '_data_fields'):
            # The field registries may only be created once.
            if not hasattr(self, name):
                super().__setattr__(name, value)
            else:
                raise AttributeError(f'{name} has been used as a '
                                     'private attribute, which is immutable.')
            return

        # Explicit raises (instead of ``assert``) keep this validation active
        # under ``python -O`` while preserving the exception type callers
        # already rely on.
        if not isinstance(value, (torch.Tensor, np.ndarray)):
            raise AssertionError(
                f'Can not set {type(value)}, only support'
                f' {(torch.Tensor, np.ndarray)}')
        # Check the rank first so a wrong-rank value is reported as such
        # rather than as a height/width mismatch.
        if value.ndim not in (2, 3):
            raise AssertionError(
                f'The dim of value must be 2 or 3, but got {value.ndim}')
        shape = self.shape
        if shape and tuple(value.shape[-2:]) != shape:
            raise AssertionError('The height and width of '
                                 f'values {tuple(value.shape[-2:])} is '
                                 'not consistent with '
                                 'the shape of this '
                                 ':obj:`PixelData` '
                                 f'{shape}')
        if value.ndim == 2:
            # Add the channel dimension: (H, W) -> (1, H, W).
            value = value[None]
            warnings.warn('The shape of value will convert from '
                          f'{value.shape[-2:]} to {value.shape}')
        super().__setattr__(name, value)

    # TODO torch.Long/bool
    def __getitem__(self, item: Sequence[Union[int, slice]]) -> 'PixelData':
        """Slice every data field along height and width.

        Args:
            item (Sequence[Union[int, slice]]): A 2-tuple
                ``(h_index, w_index)``. Each element is an int (numpy integer
                scalars are accepted too) or a slice. An int index keeps the
                dimension with length 1 instead of dropping it.

        Returns:
            :obj:`PixelData`: A new instance with the same metainfo whose
            data fields are ``value[:, h_index, w_index]``.

        Raises:
            TypeError: If ``item`` is not a tuple, or one of its elements is
                neither an int nor a slice.
            AssertionError: If ``item`` is a tuple whose length is not 2.
            IndexError: If an int index is out of bounds for its dimension.
        """
        if not isinstance(item, tuple):
            raise TypeError(
                f'Unsupported type {type(item)} for slicing PixelData')
        if len(item) != 2:
            raise AssertionError('Only support to slice height and width')

        shape = self.shape  # (H, W), or None when no data is stored
        # The channel dimension is always kept whole.
        index: List[slice] = [slice(None, None, None)]
        for axis, single_item in enumerate(item):
            if isinstance(single_item, (int, np.integer)):
                idx = int(single_item)
                if shape is not None and not -shape[axis] <= idx < shape[axis]:
                    raise IndexError(
                        f'index {idx} is out of bounds for axis {axis} '
                        f'with size {shape[axis]}')
                # ``idx + 1 or None`` turns -1 into ``slice(-1, None)`` so the
                # last row/column is selected rather than an empty slice.
                index.append(slice(idx, idx + 1 or None))
            elif isinstance(single_item, slice):
                index.append(single_item)
            else:
                raise TypeError(
                    'The type of element in input must be int or slice, '
                    f'but got {type(single_item)}')

        new_data = self.__class__(metainfo=self.metainfo)
        sliced_index = tuple(index)
        for k, v in self.items():
            setattr(new_data, k, v[sliced_index])
        return new_data

    @property
    def shape(self):
        """The ``(height, width)`` of the pixel data, taken from the last two
        dimensions of the first data value, or ``None`` if no data field has
        been set."""
        if len(self._data_fields) > 0:
            return tuple(self.values()[0].shape[-2:])
        return None

    # TODO padding, resize