rawalkhirodkar's picture
Add initial commit
28c256d
raw
history blame contribute delete
No virus
1.65 kB
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
from .base_data_element import BaseDataElement
class LabelData(BaseDataElement):
"""Data structure for label-level annotations or predictions."""
@staticmethod
def onehot_to_label(onehot: torch.Tensor) -> torch.Tensor:
"""Convert the one-hot input to label.
Args:
onehot (torch.Tensor, optional): The one-hot input. The format
of input must be one-hot.
Returns:
torch.Tensor: The converted results.
"""
assert isinstance(onehot, torch.Tensor)
if (onehot.ndim == 1 and onehot.max().item() <= 1
and onehot.min().item() >= 0):
return onehot.nonzero().squeeze(-1)
else:
raise ValueError(
'input is not one-hot and can not convert to label')
@staticmethod
def label_to_onehot(label: torch.Tensor, num_classes: int) -> torch.Tensor:
"""Convert the label-format input to one-hot.
Args:
label (torch.Tensor): The label-format input. The format
of item must be label-format.
num_classes (int): The number of classes.
Returns:
torch.Tensor: The converted results.
"""
assert isinstance(label, torch.Tensor)
onehot = label.new_zeros((num_classes, ))
assert max(label, default=torch.tensor(0)).item() < num_classes
onehot[label] = 1
return onehot