# Copyright (c) Facebook, Inc. and its affiliates.
import logging
from typing import List, Optional, Sequence, Tuple
import torch

from detectron2.layers.nms import batched_nms
from detectron2.structures.instances import Instances

from densepose.converters import ToChartResultConverterWithConfidences
from densepose.structures import (
    DensePoseChartResultWithConfidences,
    DensePoseEmbeddingPredictorOutput,
)
from densepose.vis.bounding_box import BoundingBoxVisualizer, ScoredBoundingBoxVisualizer
from densepose.vis.densepose_outputs_vertex import DensePoseOutputsVertexVisualizer
from densepose.vis.densepose_results import DensePoseResultsVisualizer

from .base import CompoundVisualizer

Scores = Sequence[float]
DensePoseChartResultsWithConfidences = List[DensePoseChartResultWithConfidences]


def extract_scores_from_instances(instances: Instances, select=None):
    """
    Return the `scores` field of the instances, optionally indexed by `select`.

    Returns None if the instances have no `scores` field.
    """
    if instances.has("scores"):
        return instances.scores if select is None else instances.scores[select]
    return None


def extract_boxes_xywh_from_instances(instances: Instances, select=None):
    """
    Return predicted boxes converted to (x, y, width, height), optionally
    indexed by `select`.

    The conversion is done on a clone, so `instances.pred_boxes` is not modified.
    Returns None if the instances have no `pred_boxes` field.
    """
    if instances.has("pred_boxes"):
        boxes_xywh = instances.pred_boxes.tensor.clone()
        # Columns 2 and 3 hold x2 / y2; subtracting x1 / y1 yields width / height.
        boxes_xywh[:, 2] -= boxes_xywh[:, 0]
        boxes_xywh[:, 3] -= boxes_xywh[:, 1]
        return boxes_xywh if select is None else boxes_xywh[select]
    return None


def create_extractor(visualizer: object):
    """
    Create an extractor for the provided visualizer

    Compound visualizers get a CompoundExtractor built recursively from their
    sub-visualizers. If the visualizer type is not recognized, an error is
    logged and None is returned.
    """
    if isinstance(visualizer, CompoundVisualizer):
        extractors = [create_extractor(v) for v in visualizer.visualizers]
        return CompoundExtractor(extractors)
    elif isinstance(visualizer, DensePoseResultsVisualizer):
        return DensePoseResultExtractor()
    elif isinstance(visualizer, ScoredBoundingBoxVisualizer):
        return CompoundExtractor([extract_boxes_xywh_from_instances, extract_scores_from_instances])
    elif isinstance(visualizer, BoundingBoxVisualizer):
        return extract_boxes_xywh_from_instances
    elif isinstance(visualizer, DensePoseOutputsVertexVisualizer):
        return DensePoseOutputsExtractor()
    else:
        logger = logging.getLogger(__name__)
        logger.error(f"Could not create extractor for {visualizer}")
        return None


class BoundingBoxExtractor:
    """
    Extracts bounding boxes from instances
    """

    def __call__(self, instances: Instances, select=None):
        """
        Return XYWH boxes, optionally indexed by `select`; None if there are
        no predicted boxes.
        """
        # `select` is accepted (default None) so this extractor can be driven
        # through the same `extractor(instances, select)` call as the others.
        return extract_boxes_xywh_from_instances(instances, select)


class ScoredBoundingBoxExtractor:
    """
    Extracts bounding boxes and their scores from instances
    """

    def __call__(self, instances: Instances, select=None):
        """
        Return a `(boxes_xywh, scores)` pair.

        If either element is missing (None), the pair is returned as is and
        `select` is not applied; otherwise both are indexed by `select`.
        """
        scores = extract_scores_from_instances(instances)
        boxes_xywh = extract_boxes_xywh_from_instances(instances)
        if (scores is None) or (boxes_xywh is None):
            return (boxes_xywh, scores)
        if select is not None:
            scores = scores[select]
            boxes_xywh = boxes_xywh[select]
        return (boxes_xywh, scores)


class DensePoseResultExtractor:
    """
    Extracts DensePose chart result with confidences from instances
    """

    def __call__(
        self, instances: Instances, select=None
    ) -> Tuple[Optional[DensePoseChartResultsWithConfidences], Optional[torch.Tensor]]:
        """
        Return `(results, boxes_xywh)` for the (optionally selected) instances,
        or `(None, None)` if `pred_densepose` or `pred_boxes` is missing.

        `results[i]` corresponds to `boxes_xywh[i]`.
        """
        if not (instances.has("pred_densepose") and instances.has("pred_boxes")):
            return None, None
        dpout = instances.pred_densepose
        boxes_xyxy = instances.pred_boxes
        # Apply the same selection to the returned XYWH boxes so that they stay
        # aligned with the per-instance results built below.
        boxes_xywh = extract_boxes_xywh_from_instances(instances, select)
        if select is not None:
            dpout = dpout[select]
            boxes_xyxy = boxes_xyxy[select]
        converter = ToChartResultConverterWithConfidences()
        # `boxes_xyxy[[i]]` uses list indexing to obtain a one-box container
        # for the i-th instance.
        results = [converter.convert(dpout[i], boxes_xyxy[[i]]) for i in range(len(dpout))]
        return results, boxes_xywh


class DensePoseOutputsExtractor:
    """
    Extracts DensePose result from instances
    """

    def __call__(
        self,
        instances: Instances,
        select=None,
    ) -> Tuple[
        Optional[DensePoseEmbeddingPredictorOutput], Optional[torch.Tensor], Optional[List[int]]
    ]:
        """
        Return `(densepose_output, boxes_xywh, classes)` for the (optionally
        selected) instances, or `(None, None, None)` if `pred_densepose` or
        `pred_boxes` is missing.

        `classes` is a plain list built from `pred_classes`, or None when the
        instances carry no `pred_classes` field.
        """
        if not (instances.has("pred_densepose") and instances.has("pred_boxes")):
            return None, None, None

        dpout = instances.pred_densepose
        # Keep boxes aligned with the selected DensePose outputs.
        boxes_xywh = extract_boxes_xywh_from_instances(instances, select)
        pred_classes = instances.pred_classes if instances.has("pred_classes") else None

        if select is not None:
            dpout = dpout[select]
            if pred_classes is not None:
                # Index the tensor, not the list: a Python list cannot be
                # indexed by a boolean mask or an index tensor.
                pred_classes = pred_classes[select]

        classes = pred_classes.tolist() if pred_classes is not None else None
        return dpout, boxes_xywh, classes


class CompoundExtractor:
    """
    Extracts data for CompoundVisualizer
    """

    def __init__(self, extractors):
        self.extractors = extractors

    def __call__(self, instances: Instances, select=None):
        """
        Call every wrapped extractor with `(instances, select)` and return the
        list of their results, in the same order as `self.extractors`.
        """
        return [extractor(instances, select) for extractor in self.extractors]


class NmsFilteredExtractor:
    """
    Extracts data in the format accepted by NmsFilteredVisualizer
    """

    def __init__(self, extractor, iou_threshold):
        self.extractor = extractor
        self.iou_threshold = iou_threshold

    def __call__(self, instances: Instances, select=None):
        """
        Run NMS over all instances (treated as a single class) and delegate to
        the wrapped extractor with a boolean mask of the surviving instances,
        combined with `select` via `&` when `select` is given.

        Returns None if the instances have no scores or no predicted boxes.
        """
        scores = extract_scores_from_instances(instances)
        if scores is None or not instances.has("pred_boxes"):
            return None
        # NMS computes IoU from corner coordinates, so it must be fed the
        # original XYXY boxes rather than the XYWH boxes used for drawing.
        boxes_xyxy = instances.pred_boxes.tensor
        # All-zero category indices: every box competes with every other box.
        # Created on the boxes' device to avoid mixing devices.
        category_idxs = torch.zeros(len(scores), dtype=torch.int32, device=boxes_xyxy.device)
        select_local_idx = batched_nms(
            boxes_xyxy,
            scores,
            category_idxs,
            iou_threshold=self.iou_threshold,
        )
        select_local = torch.zeros(len(boxes_xyxy), dtype=torch.bool, device=boxes_xyxy.device)
        select_local[select_local_idx] = True
        select = select_local if select is None else (select & select_local)
        return self.extractor(instances, select=select)


class ScoreThresholdedExtractor:
    """
    Extracts data in the format accepted by ScoreThresholdedVisualizer
    """

    def __init__(self, extractor, min_score):
        self.extractor = extractor
        self.min_score = min_score

    def __call__(self, instances: Instances, select=None):
        """
        Delegate to the wrapped extractor with a mask of the instances whose
        score is strictly greater than `min_score`, combined with `select`
        via `&` when `select` is given.

        Returns None if the instances have no scores.
        """
        scores = extract_scores_from_instances(instances)
        if scores is None:
            return None
        select_local = scores > self.min_score
        select = select_local if select is None else (select & select_local)
        data = self.extractor(instances, select=select)
        return data