# Copyright (c) Facebook, Inc. and its affiliates.

from dataclasses import dataclass
from typing import Union

import torch


@dataclass
class DensePoseChartPredictorOutput:
    """
    Predictor output that contains segmentation and inner coordinates predictions
    for predefined body parts:
     * coarse segmentation, a tensor of shape [N, K, Hout, Wout]
     * fine segmentation, a tensor of shape [N, C, Hout, Wout]
     * U coordinates, a tensor of shape [N, C, Hout, Wout]
     * V coordinates, a tensor of shape [N, C, Hout, Wout]
    where
     - N is the number of instances
     - K is the number of coarse segmentation channels
       (2 = foreground / background, or 15 = 14 body parts + background)
     - C is the number of fine segmentation channels
       (24 fine body parts + background)
     - Hout and Wout are height and width of predictions

    All four tensors share the leading instance dimension N; indexing and
    device transfer are applied to each of them in the same way.
    """

    # [N, K, Hout, Wout] coarse segmentation predictions
    coarse_segm: torch.Tensor
    # [N, C, Hout, Wout] fine segmentation predictions
    fine_segm: torch.Tensor
    # [N, C, Hout, Wout] U coordinate predictions
    u: torch.Tensor
    # [N, C, Hout, Wout] V coordinate predictions
    v: torch.Tensor

    def __len__(self):
        """
        Number of instances (N) in the output, read from the leading
        dimension of `coarse_segm`.
        """
        return self.coarse_segm.size(0)

    def __getitem__(
        self, item: Union[int, slice, torch.BoolTensor]
    ) -> "DensePoseChartPredictorOutput":
        """
        Get outputs for the selected instance(s).

        Args:
            item (int or slice or tensor): selected items. An int selects a
                single instance; a slice or boolean mask selects a subset.
        Returns:
            DensePoseChartPredictorOutput: a new output whose tensors keep the
                leading instance dimension (size 1 when `item` is an int).
        """
        if isinstance(item, int):
            # Integer indexing drops the leading dimension; unsqueeze restores it
            # so the result is still a batch (of size 1) with the documented shapes.
            return DensePoseChartPredictorOutput(
                coarse_segm=self.coarse_segm[item].unsqueeze(0),
                fine_segm=self.fine_segm[item].unsqueeze(0),
                u=self.u[item].unsqueeze(0),
                v=self.v[item].unsqueeze(0),
            )
        # Slices and boolean masks preserve the leading dimension already.
        return DensePoseChartPredictorOutput(
            coarse_segm=self.coarse_segm[item],
            fine_segm=self.fine_segm[item],
            u=self.u[item],
            v=self.v[item],
        )

    def to(self, device: torch.device) -> "DensePoseChartPredictorOutput":
        """
        Transfers all tensors to the given device.

        Args:
            device (torch.device): target device, forwarded to `Tensor.to`.
        Returns:
            DensePoseChartPredictorOutput: a new output holding the moved tensors.
        """
        coarse_segm = self.coarse_segm.to(device)
        fine_segm = self.fine_segm.to(device)
        u = self.u.to(device)
        v = self.v.to(device)
        return DensePoseChartPredictorOutput(
            coarse_segm=coarse_segm, fine_segm=fine_segm, u=u, v=v
        )