# Copyright (c) Facebook, Inc. and its affiliates.

import random
from typing import Optional, Tuple
import torch
from torch.nn import functional as F

from detectron2.config import CfgNode
from detectron2.structures import Instances

from densepose.converters.base import IntTupleBox

from .densepose_cse_base import DensePoseCSEBaseSampler


class DensePoseCSEConfidenceBasedSampler(DensePoseCSEBaseSampler):
    """
    Samples DensePose data from DensePose predictions.
    Samples for each class are drawn using confidence value estimates.
    """

    def __init__(
        self,
        cfg: CfgNode,
        use_gt_categories: bool,
        embedder: torch.nn.Module,
        confidence_channel: str,
        count_per_class: int = 8,
        search_count_multiplier: Optional[float] = None,
        search_proportion: Optional[float] = None,
    ):
        """
        Constructor

        Args:
          cfg (CfgNode): the config of the model
          use_gt_categories (bool): forwarded unchanged to the base sampler
            (presumably selects ground-truth rather than predicted categories;
            verify against `DensePoseCSEBaseSampler`)
          embedder (torch.nn.Module): necessary to compute mesh vertex embeddings
          confidence_channel (str): name of the attribute of the DensePose
            predictor output that holds the confidences used for sampling;
            possible values:
              "coarse_segm_confidence": confidences for coarse segmentation
            (required, no default)
          count_per_class (int): the sampler produces at most `count_per_class`
              samples for each category (default: 8)
          search_count_multiplier (float or None): if not None, the total number
              of the most confident estimates of a given class to consider is
              defined as `min(max(search_count_multiplier * count, count), N)`,
              where `count` is the number of samples requested for the class
              (at most `count_per_class`) and `N` is the total number of
              estimates of the class; multipliers below 1 therefore behave
              like 1; cannot be specified together with `search_proportion`
              (default: None)
          search_proportion (float or None): if not None, the total number of
              the most confident estimates of a given class to consider is
              defined as `min(max(search_proportion * N, count), N)`, where `N`
              is the total number of estimates of the class; cannot be
              specified together with `search_count_multiplier` (default: None)
        """
        super().__init__(cfg, use_gt_categories, embedder, count_per_class)
        self.confidence_channel = confidence_channel
        self.search_count_multiplier = search_count_multiplier
        self.search_proportion = search_proportion
        assert (search_count_multiplier is None) or (search_proportion is None), (
            f"Cannot specify both search_count_multiplier (={search_count_multiplier}) "
            f"and search_proportion (={search_proportion})"
        )

    def _produce_index_sample(self, values: torch.Tensor, count: int):
        """
        Produce a sample of indices to select data based on confidences

        Args:
            values (torch.Tensor): tensor whose second dimension has length k,
                the number of points labeled with the given part; only row 0
                (`values[0]`) is used as the confidence score of each point
            count (int): number of samples to produce, should be positive and <= k

        Return:
            list(int) or torch.Tensor: indices of values (along axis 1) selected
            as a sample. When `count == k` this is the plain list
            `[0, ..., k - 1]`; otherwise it is a 1D tensor of `count` indices
            drawn uniformly (without replacement) from the `search_count`
            points with the largest values in `values[0]`.
        """
        k = values.shape[1]
        if k == count:
            # every point is requested, no selection needed
            return list(range(k))
        # take the best `search_count` points and sample `count` of them
        # uniformly; here "best" = largest value in `values[0]`, i.e. the tail
        # of the ascending sort below
        _, sorted_confidence_indices = torch.sort(values[0])
        if self.search_count_multiplier is not None:
            # clamp from below by `count`: a multiplier < 1 would otherwise
            # make the candidate pool smaller than the sample size and
            # `random.sample` would raise ValueError
            search_count = min(max(int(count * self.search_count_multiplier), count), k)
        elif self.search_proportion is not None:
            search_count = min(max(int(k * self.search_proportion), count), k)
        else:
            search_count = min(count, k)
        sample_from_top = random.sample(range(search_count), count)
        return sorted_confidence_indices[-search_count:][sample_from_top]

    def _produce_mask_and_results(
        self, instance: Instances, bbox_xywh: IntTupleBox
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Method to get labels and DensePose results from an instance

        Args:
            instance (Instances): an instance of
                `DensePoseEmbeddingPredictorOutputWithConfidences`
            bbox_xywh (IntTupleBox): the corresponding bounding box

        Return:
            mask (torch.Tensor): shape [H, W], DensePose segmentation mask
            embeddings (torch.Tensor): a tensor of shape [D, H, W],
                DensePose CSE Embeddings
            other_values (torch.Tensor): a tensor of shape [1, H, W],
                DensePose CSE confidence taken from `self.confidence_channel`
        """
        _, _, w, h = bbox_xywh
        densepose_output = instance.pred_densepose
        mask, embeddings, _ = super()._produce_mask_and_results(instance, bbox_xywh)
        # resize the selected confidence channel to the box size (h, w);
        # `[0]` drops the leading dimension (presumably the batch dimension)
        # and the result is moved to CPU. align_corners=False is torch's
        # default for bilinear mode and is passed explicitly for clarity.
        other_values = F.interpolate(
            getattr(densepose_output, self.confidence_channel),
            size=(h, w),
            mode="bilinear",
            align_corners=False,
        )[0].cpu()
        return mask, embeddings, other_values