|
|
|
import numpy as np |
|
from typing import List |
|
import pycocotools.mask as mask_util |
|
|
|
from detectron2.structures import Instances |
|
from detectron2.utils.visualizer import ( |
|
ColorMode, |
|
Visualizer, |
|
_create_text_labels, |
|
_PanopticPrediction, |
|
) |
|
|
|
from .colormap import random_color, random_colors |
|
|
|
|
|
class _DetectedInstance:
    """
    Lightweight record of one object detected in a video frame, kept around
    so that its color can be handed over to the matching object in later frames.

    Attributes:
        label (int):
        bbox (tuple[float]):
        mask_rle (dict):
        color (tuple[float]): RGB colors in range (0, 1)
        ttl (int): time-to-live for the instance. For example, if ttl=2,
            the instance color can be transferred to objects in the next two frames.
    """

    # Fixed attribute set: instances carry no per-object __dict__.
    __slots__ = ["label", "bbox", "mask_rle", "color", "ttl"]

    def __init__(self, label, bbox, mask_rle, color, ttl):
        # What was detected and where (box and/or RLE mask; either may be None).
        self.label, self.bbox, self.mask_rle = label, bbox, mask_rle
        # How it is drawn and for how many more frames its color may be reused.
        self.color, self.ttl = color, ttl
|
|
|
|
|
class VideoVisualizer:
    """
    Draw predictions on successive video frames while trying to keep the
    color of each object stable from one frame to the next.

    Colors are carried over either by explicit instance IDs (when the
    predictions provide an "ID" field) or by a simple IoU-based matching
    against the instances seen in recent frames.
    """

    def __init__(self, metadata, instance_mode=ColorMode.IMAGE):
        """
        Args:
            metadata (MetadataCatalog): image metadata.
            instance_mode (ColorMode): rendering mode. Only ``ColorMode.IMAGE``
                and ``ColorMode.IMAGE_BW`` are accepted.
        """
        self.metadata = metadata
        # Instances remembered from previous frames (used by `_assign_colors`).
        self._old_instances = []
        assert instance_mode in [
            ColorMode.IMAGE,
            ColorMode.IMAGE_BW,
        ], "Other mode not supported yet."
        self._instance_mode = instance_mode
        self._max_num_instances = self.metadata.get("max_num_instances", 74)
        # State for ID-based coloring: instance id -> index into the color pool.
        self._assigned_colors = {}
        self._color_pool = random_colors(self._max_num_instances, rgb=True, maximum=1)
        # Indices of pool colors that are currently free.
        self._color_idx_set = set(range(len(self._color_pool)))

    def draw_instance_predictions(self, frame, predictions):
        """
        Draw instance-level prediction results on an image.

        Args:
            frame (ndarray): an RGB image of shape (H, W, C), in the range [0, 255].
            predictions (Instances): the output of an instance detection/segmentation
                model. Following fields will be used to draw:
                "pred_boxes", "pred_classes", "scores", "pred_masks" (or "pred_masks_rle").
                Optional fields "COLOR", "ID" and "ID_period" control color
                assignment and per-instance visibility.

        Returns:
            output (VisImage): image object with visualizations.
        """
        frame_visualizer = Visualizer(frame, self.metadata)
        num_instances = len(predictions)
        if num_instances == 0:
            return frame_visualizer.output

        boxes = predictions.pred_boxes.tensor.numpy() if predictions.has("pred_boxes") else None
        scores = predictions.scores if predictions.has("scores") else None
        classes = predictions.pred_classes.numpy() if predictions.has("pred_classes") else None
        keypoints = predictions.pred_keypoints if predictions.has("pred_keypoints") else None
        colors = predictions.COLOR if predictions.has("COLOR") else [None] * num_instances
        periods = predictions.ID_period if predictions.has("ID_period") else None
        period_threshold = self.metadata.get("period_threshold", 0)
        # An instance is drawn only if its period exceeds the threshold
        # (everything is drawn when no periods are provided).
        visibilities = (
            [True] * num_instances
            if periods is None
            else [x > period_threshold for x in periods]
        )

        # Masks are used for drawing only; the tracking below matches on
        # boxes (instances are created with mask_rle=None).
        masks = predictions.pred_masks if predictions.has("pred_masks") else None

        if not predictions.has("COLOR"):
            if predictions.has("ID"):
                colors = self._assign_colors_by_id(predictions)
            elif boxes is not None:
                # Heuristic box-IoU tracking. Without class predictions all
                # instances share the label None, so matching is by box only.
                track_labels = classes if classes is not None else [None] * num_instances
                detected = [
                    _DetectedInstance(
                        track_labels[i], boxes[i], mask_rle=None, color=colors[i], ttl=8
                    )
                    for i in range(num_instances)
                ]
                colors = self._assign_colors(detected)
            else:
                # Nothing to track on: leave color selection to the visualizer.
                colors = None

        labels = _create_text_labels(classes, scores, self.metadata.get("thing_classes", None))

        if self._instance_mode == ColorMode.IMAGE_BW:
            # Grayscale the frame, keeping color only where some mask is set.
            frame_visualizer.output.reset_image(
                frame_visualizer._create_grayscale_image(
                    (masks.any(dim=0) > 0).numpy() if masks is not None else None
                )
            )
            alpha = 0.3
        else:
            alpha = 0.5

        # Keep only the labels/colors of visible instances.
        if labels is not None:
            labels = [label for label, visible in zip(labels, visibilities) if visible]
        assigned_colors = (
            None
            if colors is None
            else [color for color, visible in zip(colors, visibilities) if visible]
        )

        # Boxes are drawn only when no masks are available; never subscript
        # a missing (None) field.
        if masks is not None:
            visible_boxes = None
            visible_masks = masks[visibilities]
        else:
            visible_boxes = None if boxes is None else boxes[visibilities]
            visible_masks = None

        frame_visualizer.overlay_instances(
            boxes=visible_boxes,
            masks=visible_masks,
            labels=labels,
            keypoints=None if keypoints is None else keypoints[visibilities],
            assigned_colors=assigned_colors,
            alpha=alpha,
        )

        return frame_visualizer.output

    def draw_sem_seg(self, frame, sem_seg, area_threshold=None):
        """
        Draw a semantic segmentation on a frame.

        Args:
            frame (ndarray): the image to draw on.
            sem_seg (ndarray or Tensor): semantic segmentation of shape (H, W),
                each value is the integer label.
            area_threshold (Optional[int]): only draw segmentations larger than the threshold

        Returns:
            output (VisImage): image object with visualizations.
        """
        frame_visualizer = Visualizer(frame, self.metadata)
        # Forward the caller's threshold (it used to be dropped in favor of None).
        frame_visualizer.draw_sem_seg(sem_seg, area_threshold=area_threshold)
        return frame_visualizer.output

    def draw_panoptic_seg_predictions(
        self, frame, panoptic_seg, segments_info, area_threshold=None, alpha=0.5
    ):
        """
        Draw panoptic predictions: semantic ("stuff") segments first, then
        instance ("thing") segments with colors tracked across frames.

        Args:
            frame (ndarray): the image to draw on.
            panoptic_seg: panoptic segmentation, passed to `_PanopticPrediction`.
            segments_info: per-segment information, passed to `_PanopticPrediction`;
                each entry is read for its "category_id".
            area_threshold (Optional[int]): forwarded to `draw_binary_mask`
                for the semantic segments.
            alpha (float): blending factor forwarded to the drawing calls.

        Returns:
            output (VisImage): image object with visualizations.
        """
        frame_visualizer = Visualizer(frame, self.metadata)
        pred = _PanopticPrediction(panoptic_seg, segments_info, self.metadata)

        if self._instance_mode == ColorMode.IMAGE_BW:
            frame_visualizer.output.reset_image(
                frame_visualizer._create_grayscale_image(pred.non_empty_mask())
            )

        # Semantic segments are drawn first.
        for mask, sinfo in pred.semantic_masks():
            category_idx = sinfo["category_id"]
            try:
                mask_color = [x / 255 for x in self.metadata.stuff_colors[category_idx]]
            except AttributeError:
                # Metadata has no `stuff_colors`: let the visualizer pick a color.
                mask_color = None

            frame_visualizer.draw_binary_mask(
                mask,
                color=mask_color,
                text=self.metadata.stuff_classes[category_idx],
                alpha=alpha,
                area_threshold=area_threshold,
            )

        all_instances = list(pred.instance_masks())
        if len(all_instances) == 0:
            return frame_visualizer.output

        # Instance segments are drawn second, with colors tracked by mask IoU.
        masks, instance_infos = zip(*all_instances)
        num_instances = len(masks)
        # Stack the masks to (H, W, N) uint8 in Fortran order before RLE encoding.
        masks_rles = mask_util.encode(
            np.asarray(np.asarray(masks).transpose(1, 2, 0), dtype=np.uint8, order="F")
        )
        assert len(masks_rles) == num_instances

        category_ids = [x["category_id"] for x in instance_infos]
        detected = [
            _DetectedInstance(category_ids[i], bbox=None, mask_rle=masks_rles[i], color=None, ttl=8)
            for i in range(num_instances)
        ]
        colors = self._assign_colors(detected)
        labels = [self.metadata.thing_classes[k] for k in category_ids]

        frame_visualizer.overlay_instances(
            boxes=None,
            masks=masks,
            labels=labels,
            keypoints=None,
            assigned_colors=colors,
            alpha=alpha,
        )
        return frame_visualizer.output

    def _assign_colors(self, instances):
        """
        Naive tracking heuristics to assign same color to the same instance,
        will update the internal state of tracked instances.

        Args:
            instances (list[_DetectedInstance]): detections of the current
                frame. Their `color` attribute is filled in place. All of
                them are expected to carry the same kind of geometry
                (boxes, or RLE masks when `bbox` is None).

        Returns:
            list[tuple[float]]: list of colors.
        """
        if not instances:
            # Nothing to match; leave the tracked state untouched.
            return []

        # Compute IoU between remembered and new instances; none of the new
        # instances is flagged as crowd.
        is_crowd = np.zeros((len(instances),), dtype=bool)
        if instances[0].bbox is None:
            assert instances[0].mask_rle is not None
            # No boxes available: use IoU between RLE masks instead.
            rles_old = [x.mask_rle for x in self._old_instances]
            rles_new = [x.mask_rle for x in instances]
            ious = mask_util.iou(rles_old, rles_new, is_crowd)
            threshold = 0.5
        else:
            boxes_old = [x.bbox for x in self._old_instances]
            boxes_new = [x.bbox for x in instances]
            ious = mask_util.iou(boxes_old, boxes_new, is_crowd)
            threshold = 0.6
        if len(ious) == 0:
            # Empty result (e.g. nothing remembered yet): use a matrix of
            # the expected (old, new) shape so the code below stays uniform.
            ious = np.zeros((len(self._old_instances), len(instances)), dtype="float32")
        ious = np.asarray(ious)

        # Only allow matching instances of the same label.
        for old_idx, old in enumerate(self._old_instances):
            for new_idx, new in enumerate(instances):
                if old.label != new.label:
                    ious[old_idx, new_idx] = 0

        matched_new_per_old = ious.argmax(axis=1)
        max_iou_per_old = ious.max(axis=1)

        # Try to find a match for each old instance.
        extra_instances = []
        for idx, inst in enumerate(self._old_instances):
            if max_iou_per_old[idx] > threshold:
                newidx = matched_new_per_old[idx]
                if instances[newidx].color is None:
                    # The old instance is superseded by its match, which
                    # inherits the color and is remembered in its place.
                    instances[newidx].color = inst.color
                    continue
            # Unmatched (or its match was already claimed): keep it around
            # for a few more frames in case it reappears.
            inst.ttl -= 1
            if inst.ttl > 0:
                extra_instances.append(inst)

        # Assign random colors to the newly-detected instances.
        for inst in instances:
            if inst.color is None:
                inst.color = random_color(rgb=True, maximum=1)
        self._old_instances = instances[:] + extra_instances
        return [d.color for d in instances]

    def _assign_colors_by_id(self, instances: Instances) -> List:
        """
        Assign one color per instance based on its tracking ID, reusing the
        color already associated with a known ID.

        Colors of IDs that were known before this call but are absent from
        `instances.ID` are returned to the pool after the current frame's
        colors have been assigned.

        Args:
            instances (Instances): predictions with an "ID" field; the IDs are
                used as dictionary keys and therefore must be hashable.

        Returns:
            list: one color from the pool per entry of `instances.ID`.

        Raises:
            AssertionError: if a new ID shows up while no free color is left.
        """
        colors = []
        untracked_ids = set(self._assigned_colors.keys())
        for instance_id in instances.ID:
            if instance_id in self._assigned_colors:
                colors.append(self._color_pool[self._assigned_colors[instance_id]])
                # discard, not remove: the same ID may be listed more than
                # once in a frame, or may have been first seen in this frame.
                untracked_ids.discard(instance_id)
            else:
                if not self._color_idx_set:
                    # Raised explicitly so the check survives `python -O`.
                    raise AssertionError(
                        f"Number of id exceeded maximum, max = {self._max_num_instances}"
                    )
                idx = self._color_idx_set.pop()
                self._assigned_colors[instance_id] = idx
                colors.append(self._color_pool[idx])
        # Free the colors of IDs that did not show up in this frame.
        for instance_id in untracked_ids:
            self._color_idx_set.add(self._assigned_colors[instance_id])
            del self._assigned_colors[instance_id]
        return colors
|
|