|
|
|
import logging |
|
from typing import List, Optional, Sequence, Tuple |
|
import torch |
|
|
|
from detectron2.layers.nms import batched_nms |
|
from detectron2.structures.instances import Instances |
|
|
|
from densepose.converters import ToChartResultConverterWithConfidences |
|
from densepose.structures import ( |
|
DensePoseChartResultWithConfidences, |
|
DensePoseEmbeddingPredictorOutput, |
|
) |
|
from densepose.vis.bounding_box import BoundingBoxVisualizer, ScoredBoundingBoxVisualizer |
|
from densepose.vis.densepose_outputs_vertex import DensePoseOutputsVertexVisualizer |
|
from densepose.vis.densepose_results import DensePoseResultsVisualizer |
|
|
|
from .base import CompoundVisualizer |
|
|
|
# Alias for a sequence of float scores.
Scores = Sequence[float]

# List of DensePose chart results with confidences; this is the first element
# of the tuple returned by DensePoseResultExtractor.__call__.
DensePoseChartResultsWithConfidences = List[DensePoseChartResultWithConfidences]
|
|
|
|
|
def extract_scores_from_instances(instances: Instances, select=None):
    """
    Return the ``scores`` field of ``instances``, optionally indexed by ``select``.

    Args:
        instances: detection results to read from.
        select: optional index / mask applied to the scores; ``None`` keeps all.

    Returns:
        The (possibly selected) scores, or ``None`` when ``instances`` has no
        ``scores`` field.
    """
    if not instances.has("scores"):
        return None
    scores = instances.scores
    if select is not None:
        scores = scores[select]
    return scores
|
|
|
|
|
def extract_boxes_xywh_from_instances(instances: Instances, select=None):
    """
    Convert the predicted boxes of ``instances`` from XYXY to XYWH layout.

    The input tensor is not modified; a converted copy is returned.

    Args:
        instances: detection results to read ``pred_boxes`` from.
        select: optional index / mask applied to the converted boxes;
            ``None`` keeps all.

    Returns:
        Tensor of (possibly selected) XYWH boxes, or ``None`` when
        ``instances`` has no ``pred_boxes`` field.
    """
    if not instances.has("pred_boxes"):
        return None
    xyxy = instances.pred_boxes.tensor
    xywh = xyxy.clone()
    # columns 2:4 become (x1 - x0, y1 - y0); the top-left corner stays as-is
    xywh[:, 2:4] -= xyxy[:, 0:2]
    if select is not None:
        xywh = xywh[select]
    return xywh
|
|
|
|
|
def create_extractor(visualizer: object):
    """
    Build the extractor that produces input data for ``visualizer``.

    A CompoundVisualizer is handled recursively: each child visualizer gets its
    own extractor and the results are wrapped in a CompoundExtractor. For an
    unsupported visualizer type an error is logged and ``None`` is returned.
    """
    if isinstance(visualizer, CompoundVisualizer):
        return CompoundExtractor(list(map(create_extractor, visualizer.visualizers)))
    if isinstance(visualizer, DensePoseResultsVisualizer):
        return DensePoseResultExtractor()
    # NOTE: the scored-box check must come before the plain bounding-box check
    if isinstance(visualizer, ScoredBoundingBoxVisualizer):
        # produces [boxes_xywh, scores]
        return CompoundExtractor([extract_boxes_xywh_from_instances, extract_scores_from_instances])
    if isinstance(visualizer, BoundingBoxVisualizer):
        return extract_boxes_xywh_from_instances
    if isinstance(visualizer, DensePoseOutputsVertexVisualizer):
        return DensePoseOutputsExtractor()
    logging.getLogger(__name__).error(f"Could not create extractor for {visualizer}")
    return None
|
|
|
|
|
class BoundingBoxExtractor:
    """
    Extracts bounding boxes (converted to XYWH) from instances
    """

    def __call__(self, instances: Instances, select=None):
        """
        Args:
            instances: detection results to read ``pred_boxes`` from.
            select: optional index / mask applied to the boxes. It is accepted
                because CompoundExtractor, NmsFilteredExtractor and
                ScoreThresholdedExtractor call their wrapped extractors with a
                ``select`` argument; ``None`` keeps all boxes.

        Returns:
            Tensor of (possibly selected) XYWH boxes, or ``None`` when
            ``instances`` has no ``pred_boxes`` field.
        """
        return extract_boxes_xywh_from_instances(instances, select)
|
|
|
|
|
class ScoredBoundingBoxExtractor:
    """
    Extracts bounding boxes (converted to XYWH) together with their scores
    from instances
    """

    def __call__(self, instances: Instances, select=None):
        """
        Args:
            instances: detection results to read ``pred_boxes`` and ``scores``
                from.
            select: optional index / mask applied to both boxes and scores;
                ``None`` keeps all.

        Returns:
            Tuple ``(boxes_xywh, scores)``. Each element is ``None`` when the
            corresponding field is missing. ``select`` is applied to every
            element that is present, so a partially populated ``instances``
            never yields unfiltered data.
        """
        boxes_xywh = extract_boxes_xywh_from_instances(instances, select)
        scores = extract_scores_from_instances(instances, select)
        return (boxes_xywh, scores)
|
|
|
|
|
class DensePoseResultExtractor:
    """
    Extracts DensePose chart result with confidences from instances
    """

    def __call__(
        self, instances: Instances, select=None
    ) -> Tuple[Optional[DensePoseChartResultsWithConfidences], Optional[torch.Tensor]]:
        """
        Args:
            instances: detection results; must carry ``pred_densepose`` and
                ``pred_boxes`` for anything to be extracted.
            select: optional index / mask selecting a subset of instances;
                ``None`` keeps all.

        Returns:
            ``(results, boxes_xywh)``, where ``results[i]`` is the converted
            chart result for the i-th selected instance and ``boxes_xywh[i]``
            is that same instance's box. Returns ``(None, None)`` when either
            required field is missing.
        """
        if not (instances.has("pred_densepose") and instances.has("pred_boxes")):
            return None, None
        dpout = instances.pred_densepose
        boxes_xyxy = instances.pred_boxes
        # Apply the same selection as to dpout so results and boxes stay aligned.
        boxes_xywh = extract_boxes_xywh_from_instances(instances, select)
        if select is not None:
            dpout = dpout[select]
            boxes_xyxy = boxes_xyxy[select]
        converter = ToChartResultConverterWithConfidences()
        # boxes_xyxy[[i]] keeps a length-1 leading dimension (list index), which
        # the converter presumably expects for a single-instance batch.
        results = [converter.convert(dpout[i], boxes_xyxy[[i]]) for i in range(len(dpout))]
        return results, boxes_xywh
|
|
|
|
|
class DensePoseOutputsExtractor:
    """
    Extracts DensePose result from instances
    """

    def __call__(
        self,
        instances: Instances,
        select=None,
    ) -> Tuple[
        Optional[DensePoseEmbeddingPredictorOutput], Optional[torch.Tensor], Optional[List[int]]
    ]:
        """
        Args:
            instances: detection results; must carry ``pred_densepose`` and
                ``pred_boxes`` for anything to be extracted.
            select: optional index / mask selecting a subset of instances;
                ``None`` keeps all.

        Returns:
            ``(dpout, boxes_xywh, classes)``, all restricted to the same
            selected instances. ``classes`` is a list built from
            ``pred_classes``, or ``None`` when that field is missing. Returns
            ``(None, None, None)`` when a required field is missing.
        """
        if not (instances.has("pred_densepose") and instances.has("pred_boxes")):
            return None, None, None

        dpout = instances.pred_densepose
        # Filter boxes with the same selection as dpout so indices stay aligned.
        boxes_xywh = extract_boxes_xywh_from_instances(instances, select)
        classes_tensor = instances.pred_classes if instances.has("pred_classes") else None

        if select is not None:
            dpout = dpout[select]
            if classes_tensor is not None:
                # Select on the tensor: a Python list cannot be indexed by a
                # tensor mask / index tensor.
                classes_tensor = classes_tensor[select]

        classes = classes_tensor.tolist() if classes_tensor is not None else None
        return dpout, boxes_xywh, classes
|
|
|
|
|
class CompoundExtractor:
    """
    Extracts data for CompoundVisualizer.

    Every wrapped extractor is run on the same instances and selection; their
    outputs are returned as a list in the order the extractors were given.
    """

    def __init__(self, extractors):
        self.extractors = extractors

    def __call__(self, instances: Instances, select=None):
        return [extract(instances, select) for extract in self.extractors]
|
|
|
|
|
class NmsFilteredExtractor:
    """
    Extracts data in the format accepted by NmsFilteredVisualizer.

    Instances suppressed by non-maximum suppression are removed from the
    selection before delegating to the wrapped extractor.
    """

    def __init__(self, extractor, iou_threshold):
        self.extractor = extractor
        self.iou_threshold = iou_threshold

    def __call__(self, instances: Instances, select=None):
        """
        Args:
            instances: detection results with ``pred_boxes`` and ``scores``.
            select: optional boolean mask combined (logical AND) with the NMS
                keep-mask; ``None`` means only the NMS mask is used.

        Returns:
            Output of the wrapped extractor for the surviving instances, or
            ``None`` when ``instances`` lacks ``pred_boxes`` or ``scores``
            (NMS cannot run without both).
        """
        scores = extract_scores_from_instances(instances)
        if scores is None or not instances.has("pred_boxes"):
            return None
        # NMS computes IoU on corner-format (x0, y0, x1, y1) boxes, so use the
        # raw XYXY prediction tensor, not the XYWH conversion.
        boxes_xyxy = instances.pred_boxes.tensor
        # All boxes share category id 0, so suppression runs across all boxes.
        # Create the ids on the boxes' device to avoid a CPU/GPU mismatch.
        category_ids = torch.zeros(len(scores), dtype=torch.int32, device=boxes_xyxy.device)
        keep_idx = batched_nms(
            boxes_xyxy,
            scores,
            category_ids,
            iou_threshold=self.iou_threshold,
        )
        select_local = torch.zeros(len(boxes_xyxy), dtype=torch.bool, device=boxes_xyxy.device)
        select_local[keep_idx] = True
        select = select_local if select is None else (select & select_local)
        return self.extractor(instances, select=select)
|
|
|
|
|
class ScoreThresholdedExtractor:
    """
    Extracts data in the format accepted by ScoreThresholdedVisualizer.

    Only instances whose score is strictly greater than ``min_score`` are
    passed on to the wrapped extractor. If ``instances`` has no scores,
    ``None`` is returned.
    """

    def __init__(self, extractor, min_score):
        self.extractor = extractor
        self.min_score = min_score

    def __call__(self, instances: Instances, select=None):
        scores = extract_scores_from_instances(instances)
        if scores is None:
            return None
        above_threshold = scores > self.min_score
        if select is not None:
            # combine with the caller-provided mask (logical AND)
            above_threshold = select & above_threshold
        return self.extractor(instances, select=above_threshold)
|
|