# Copyright (c) Facebook, Inc. and its affiliates.

from typing import Any, Dict, List, Tuple
import torch
from torch.nn import functional as F

from detectron2.config import CfgNode
from detectron2.structures import Instances

from densepose.converters.base import IntTupleBox
from densepose.data.utils import get_class_to_mesh_name_mapping
from densepose.modeling.cse.utils import squared_euclidean_distance_matrix
from densepose.structures import DensePoseDataRelative

from .densepose_base import DensePoseBaseSampler


class DensePoseCSEBaseSampler(DensePoseBaseSampler):
    """
    Base DensePose sampler to produce DensePose data from DensePose predictions.
    Samples for each class are drawn according to some distribution over all pixels
    estimated to belong to that class.
    """

    def __init__(
        self,
        cfg: CfgNode,
        use_gt_categories: bool,
        embedder: torch.nn.Module,
        count_per_class: int = 8,
    ):
        """
        Constructor

        Args:
            cfg (CfgNode): the config of the model
            use_gt_categories (bool): if True, take the instance category from the
                ground truth (`dataset_classes`) instead of the predictions
                (`pred_classes`)
            embedder (torch.nn.Module): necessary to compute mesh vertex embeddings
            count_per_class (int): the sampler produces at most `count_per_class`
                samples for each category
        """
        super().__init__(count_per_class)
        self.embedder = embedder
        self.class_to_mesh_name = get_class_to_mesh_name_mapping(cfg)
        self.use_gt_categories = use_gt_categories

    def _sample(self, instance: Instances, bbox_xywh: IntTupleBox) -> Dict[str, List[Any]]:
        """
        Sample DensePoseDataRelative from the estimation results of one instance
        """
        if self.use_gt_categories:
            instance_class = instance.dataset_classes.tolist()[0]
        else:
            instance_class = instance.pred_classes.tolist()[0]
        mesh_name = self.class_to_mesh_name[instance_class]

        annotation = {
            DensePoseDataRelative.X_KEY: [],
            DensePoseDataRelative.Y_KEY: [],
            DensePoseDataRelative.VERTEX_IDS_KEY: [],
            DensePoseDataRelative.MESH_NAME_KEY: mesh_name,
        }

        mask, embeddings, other_values = self._produce_mask_and_results(instance, bbox_xywh)
        # (y, x) coordinates of all foreground pixels within the box
        indices = torch.nonzero(mask, as_tuple=True)
        # embeddings of the foreground pixels, shape [n_pixels, D]
        selected_embeddings = embeddings.permute(1, 2, 0)[indices].cpu()
        values = other_values[:, indices[0], indices[1]]
        k = values.shape[1]

        count = min(self.count_per_class, k)
        if count <= 0:
            return annotation

        index_sample = self._produce_index_sample(values, count)
        # assign each sampled pixel to the mesh vertex with the closest embedding
        closest_vertices = squared_euclidean_distance_matrix(
            selected_embeddings[index_sample], self.embedder(mesh_name)
        )
        closest_vertices = torch.argmin(closest_vertices, dim=1)

        sampled_y = indices[0][index_sample] + 0.5
        sampled_x = indices[1][index_sample] + 0.5
        # normalize coordinates to the 256x256 DensePose annotation frame
        _, _, w, h = bbox_xywh
        x = (sampled_x / w * 256.0).cpu().tolist()
        y = (sampled_y / h * 256.0).cpu().tolist()
        # extend annotations
        annotation[DensePoseDataRelative.X_KEY].extend(x)
        annotation[DensePoseDataRelative.Y_KEY].extend(y)
        annotation[DensePoseDataRelative.VERTEX_IDS_KEY].extend(closest_vertices.cpu().tolist())
        return annotation

    def _produce_mask_and_results(
        self, instance: Instances, bbox_xywh: IntTupleBox
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Method to get labels and DensePose results from an instance

        Args:
            instance (Instances): a detected instance whose `pred_densepose` field
                holds a `DensePoseEmbeddingPredictorOutput`
            bbox_xywh (IntTupleBox): the corresponding bounding box

        Return:
            mask (torch.Tensor): shape [H, W], DensePose segmentation mask
            embeddings (torch.Tensor): a tensor of shape [D, H, W],
                DensePose CSE embeddings
            other_values (torch.Tensor): a tensor of shape [0, H, W],
                for potential other values
        """
        densepose_output = instance.pred_densepose
        S = densepose_output.coarse_segm
        E = densepose_output.embedding
        _, _, w, h = bbox_xywh
        embeddings = F.interpolate(E, size=(h, w), mode="bilinear")[0]
        coarse_segm_resized = F.interpolate(S, size=(h, w), mode="bilinear")[0]
        mask = coarse_segm_resized.argmax(0) > 0
        other_values = torch.empty((0, h, w), device=E.device)
        return mask, embeddings, other_values

    def _resample_mask(self, output: Any) -> torch.Tensor:
        """
        Convert DensePose predictor output to segmentation annotation - tensors of size
        (256, 256) and type `int64`.

        Args:
            output: DensePose predictor output with the following attributes:
                - coarse_segm: tensor of size [N, D, H, W] with unnormalized coarse
                  segmentation scores
        Return:
            Tensor of size (S, S) and type `int64` with coarse segmentation annotations,
            where S = DensePoseDataRelative.MASK_SIZE
        """
        sz = DensePoseDataRelative.MASK_SIZE
        mask = (
            F.interpolate(output.coarse_segm, (sz, sz), mode="bilinear", align_corners=False)
            .argmax(dim=1)
            .long()
            .squeeze()
            .cpu()
        )
        return mask
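

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only; the class and variable names below are
# hypothetical and not part of this module). `_produce_index_sample` is left
# to subclasses; a minimal subclass could pick foreground pixels uniformly at
# random:
#
#     class UniformCSESampler(DensePoseCSEBaseSampler):
#         def _produce_index_sample(self, values: torch.Tensor, count: int):
#             # `values` has shape [n_channels, k]; draw `count` of the k
#             # candidate pixels uniformly, without replacement
#             k = values.shape[1]
#             return torch.randperm(k)[:count].tolist()
#
#     # assuming `cfg`, `embedder`, `instance` and `bbox_xywh` are prepared
#     # by the surrounding DensePose data pipeline
#     sampler = UniformCSESampler(cfg, use_gt_categories=False, embedder=embedder)
#     annotation = sampler._sample(instance, bbox_xywh)
#     # -> {X_KEY: [...], Y_KEY: [...], VERTEX_IDS_KEY: [...], MESH_NAME_KEY: mesh_name}
# ---------------------------------------------------------------------------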