Spaces:
Build error
Build error
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. | |
# Note: This file has been borrowed from the facebookresearch/slowfast repo. It is used to add the bounding boxes and predictions to the frame.
# TODO: Migrate this into the core PyTorchVideo library.
from __future__ import annotations | |
import itertools | |
# import logging | |
from types import SimpleNamespace | |
from typing import Dict, List, Optional, Tuple, Union | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import torch | |
from detectron2.utils.visualizer import Visualizer | |
# logger = logging.getLogger(__name__) | |
def _create_text_labels(
    classes: List[int],
    scores: Optional[List[float]],
    class_names: Union[Dict[int, str], List[str]],
    ground_truth: bool = False,
) -> List[str]:
    """
    Create text labels.
    Args:
        classes (list[int]): a list of class ids for each example.
        scores (list[float] or None): list of scores for each example.
        class_names (dict[int, str] or list[str]): class names, either as a
            mapping from class id to name or as a list ordered by class id.
        ground_truth (bool): whether the labels are ground truth.
    Returns:
        labels (list[str]): formatted text labels. Class ids without a known
            name are rendered as "n/a".
    """
    if hasattr(class_names, "get"):
        # Mapping-style lookup: unknown ids fall back to "n/a".
        labels = [class_names.get(c, "n/a") for c in classes]
    else:
        # Sequence-style lookup: guard the index so out-of-range ids also
        # fall back to "n/a" instead of raising (or wrapping around for
        # negative ids).
        num_names = len(class_names)
        labels = [class_names[c] if 0 <= c < num_names else "n/a" for c in classes]

    if ground_truth:
        return ["[{}] {}".format("GT", label) for label in labels]
    if scores is not None:
        assert len(classes) == len(scores)
        return ["[{:.2f}] {}".format(s, label) for s, label in zip(scores, labels)]
    return labels
class ImgVisualizer(Visualizer):
    """
    detectron2 ``Visualizer`` extended with helpers that stack several text
    labels next to a bounding box.
    """

    def __init__(
        self, img_rgb: torch.Tensor, meta: Optional[SimpleNamespace] = None, **kwargs
    ) -> None:
        """
        See https://github.com/facebookresearch/detectron2/blob/main/detectron2/utils/visualizer.py
        for more details.
        Args:
            img_rgb: a tensor or numpy array of shape (H, W, C), where H and W correspond to
                the height and width of the image respectively. C is the number of
                color channels. The image is required to be in RGB format since that
                is a requirement of the Matplotlib library. The image is also expected
                to be in the range [0, 255].
            meta (MetadataCatalog): image metadata.
                See https://github.com/facebookresearch/detectron2/blob/81d5a87763bfc71a492b5be89b74179bd7492f6b/detectron2/data/catalog.py#L90
        """
        super().__init__(img_rgb, meta, **kwargs)

    def draw_text(
        self,
        text: str,
        position: Union[List[float], Tuple[float, float]],
        *,
        font_size: Optional[int] = None,
        color: str = "w",
        horizontal_alignment: str = "center",
        vertical_alignment: str = "bottom",
        box_facecolor: str = "black",
        alpha: float = 0.5,
    ) -> None:
        """
        Draw text at the specified position.
        Args:
            text (str): the text to draw on image.
            position (sequence of 2 numbers): the x,y coordinate to place the text.
            font_size (Optional[int]): font of the text. If not provided,
                `self._default_font_size` is used.
            color (str): color of the text. Refer to `matplotlib.colors` for full list
                of formats that are accepted.
            horizontal_alignment (str): see `matplotlib.text.Text`.
            vertical_alignment (str): see `matplotlib.text.Text`.
            box_facecolor (str): color of the box wrapped around the text. Refer to
                `matplotlib.colors` for full list of formats that are accepted.
            alpha (float): transparency level of the box.
        """
        if not font_size:
            font_size = self._default_font_size
        x, y = position
        self.output.ax.text(
            x,
            y,
            text,
            size=font_size * self.output.scale,
            family="monospace",
            bbox={
                "facecolor": box_facecolor,
                "alpha": alpha,
                "pad": 0.7,
                "edgecolor": "none",
            },
            verticalalignment=vertical_alignment,
            horizontalalignment=horizontal_alignment,
            color=color,
            zorder=10,
        )

    def draw_multiple_text(
        self,
        text_ls: List[str],
        box_coordinate: torch.Tensor,
        *,
        top_corner: bool = True,
        font_size: Optional[int] = None,
        color: str = "w",
        box_facecolors: Union[str, List[str]] = "black",
        alpha: float = 0.5,
    ) -> None:
        """
        Draw a list of text labels for some bounding box on the image.
        Labels that do not fit on the preferred side of the box edge are drawn
        on the other side of that edge.
        Args:
            text_ls (list of strings): a list of text labels.
            box_coordinate (tensor): shape (4,). The (x_left, y_top, x_right, y_bottom)
                coordinates of the box.
            top_corner (bool): If True, draw the text labels at (x_left, y_top) of the box.
                Else, draw labels at (x_left, y_bottom).
            font_size (Optional[int]): font of the text. If not provided,
                `self._default_font_size` is used.
            color (str): color of the text. Refer to `matplotlib.colors` for full list
                of formats that are accepted.
            box_facecolors (str or list): colors of the box wrapped around the text;
                a single value is applied to every label. Refer to
                `matplotlib.colors` for full list of formats that are accepted.
            alpha (float): transparency level of the box.
        """
        if not isinstance(box_facecolors, list):
            box_facecolors = [box_facecolors] * len(text_ls)
        assert len(box_facecolors) == len(
            text_ls
        ), "Number of colors provided is not equal to the number of text labels."
        if not font_size:
            font_size = self._default_font_size
        text_box_width = font_size + font_size // 2

        # Decide how many labels go above the chosen edge; the remainder is
        # drawn below it. Both helpers return a count in [0, len(text_ls)],
        # so the split index below is always a valid, non-negative slice bound.
        if top_corner:
            num_text_split = self._align_y_top(
                box_coordinate, len(text_ls), text_box_width
            )
            y_corner = 1
        else:
            num_text_split = len(text_ls) - self._align_y_bottom(
                box_coordinate, len(text_ls), text_box_width
            )
            y_corner = 3

        # Order labels (with their colors) by text, descending.
        ordered = sorted(
            zip(text_ls, box_facecolors), key=lambda pair: pair[0], reverse=True
        )
        text_ls = [text for text, _ in ordered]
        box_facecolors = [facecolor for _, facecolor in ordered]

        self.draw_multiple_text_upward(
            text_ls[:num_text_split][::-1],
            box_coordinate,
            y_corner=y_corner,
            font_size=font_size,
            color=color,
            box_facecolors=box_facecolors[:num_text_split][::-1],
            alpha=alpha,
        )
        self.draw_multiple_text_downward(
            text_ls[num_text_split:],
            box_coordinate,
            y_corner=y_corner,
            font_size=font_size,
            color=color,
            box_facecolors=box_facecolors[num_text_split:],
            alpha=alpha,
        )

    def draw_multiple_text_upward(
        self,
        text_ls: List[str],
        box_coordinate: torch.Tensor,
        *,
        y_corner: int = 1,
        font_size: Optional[int] = None,
        color: str = "w",
        box_facecolors: Union[str, List[str]] = "black",
        alpha: float = 0.5,
    ) -> None:
        """
        Draw a list of text labels for some bounding box on the image in upward direction.
        The next text label will be on top of the previous one.
        Args:
            text_ls (list of strings): a list of text labels.
            box_coordinate (tensor): shape (4,). The (x_left, y_top, x_right, y_bottom)
                coordinates of the box.
            y_corner (int): Value of either 1 or 3. Indicate the index of the y-coordinate of
                the box to draw labels around.
            font_size (Optional[int]): font of the text. If not provided,
                `self._default_font_size` is used.
            color (str): color of the text. Refer to `matplotlib.colors` for full list
                of formats that are accepted.
            box_facecolors (str or list of strs): colors of the box wrapped around the
                text. Refer to `matplotlib.colors` for full list of formats that
                are accepted.
            alpha (float): transparency level of the box.
        """
        self._draw_text_stack(
            text_ls,
            box_coordinate,
            upward=True,
            y_corner=y_corner,
            font_size=font_size,
            color=color,
            box_facecolors=box_facecolors,
            alpha=alpha,
        )

    def draw_multiple_text_downward(
        self,
        text_ls: List[str],
        box_coordinate: torch.Tensor,
        *,
        y_corner: int = 1,
        font_size: Optional[int] = None,
        color: str = "w",
        box_facecolors: Union[str, List[str]] = "black",
        alpha: float = 0.5,
    ) -> None:
        """
        Draw a list of text labels for some bounding box on the image in downward direction.
        The next text label will be below the previous one.
        Args:
            text_ls (list of strings): a list of text labels.
            box_coordinate (tensor): shape (4,). The (x_left, y_top, x_right, y_bottom)
                coordinates of the box.
            y_corner (int): Value of either 1 or 3. Indicate the index of the y-coordinate of
                the box to draw labels around.
            font_size (Optional[int]): font of the text. If not provided,
                `self._default_font_size` is used.
            color (str): color of the text. Refer to `matplotlib.colors` for full list
                of formats that are accepted.
            box_facecolors (str or list of strs): colors of the box wrapped around the
                text. Refer to `matplotlib.colors` for full list of formats that
                are accepted.
            alpha (float): transparency level of the box.
        """
        self._draw_text_stack(
            text_ls,
            box_coordinate,
            upward=False,
            y_corner=y_corner,
            font_size=font_size,
            color=color,
            box_facecolors=box_facecolors,
            alpha=alpha,
        )

    def _draw_text_stack(
        self,
        text_ls: List[str],
        box_coordinate: torch.Tensor,
        *,
        upward: bool,
        y_corner: int,
        font_size: Optional[int],
        color: str,
        box_facecolors: Union[str, List[str]],
        alpha: float,
    ) -> None:
        """
        Shared implementation of the upward/downward label stacks.
        Starting at `box_coordinate[y_corner]`, draw each label one text-box
        height (1.5 x font size) further away from the starting edge.
        Args:
            upward (bool): if True, labels are bottom-aligned and y decreases
                after each label; otherwise they are top-aligned and y increases.
            Remaining arguments: see `draw_multiple_text_upward`.
        """
        if not isinstance(box_facecolors, list):
            box_facecolors = [box_facecolors] * len(text_ls)
        assert len(box_facecolors) == len(
            text_ls
        ), "Number of colors provided is not equal to the number of text labels."
        assert y_corner in [1, 3], "Y_corner must be either 1 or 3"
        if not font_size:
            font_size = self._default_font_size

        x, horizontal_alignment = self._align_x_coordinate(box_coordinate)
        # float() accepts tensor elements, numpy scalars and plain numbers.
        y = float(box_coordinate[y_corner])
        line_height = font_size + font_size // 2
        step = -line_height if upward else line_height
        vertical_alignment = "bottom" if upward else "top"

        for text, facecolor in zip(text_ls, box_facecolors):
            self.draw_text(
                text,
                (x, y),
                font_size=font_size,
                color=color,
                horizontal_alignment=horizontal_alignment,
                vertical_alignment=vertical_alignment,
                box_facecolor=facecolor,
                alpha=alpha,
            )
            y += step

    def _align_x_coordinate(self, box_coordinate: torch.Tensor) -> Tuple[float, str]:
        """
        Choose an x-coordinate from the box to make sure the text label
        does not go out of frames. By default, the left x-coordinate is
        chosen and text is aligned left. If the box is too close to the
        right side of the image, then the right x-coordinate is chosen
        instead and the text is aligned right.
        Args:
            box_coordinate (array-like): shape (4,). The (x_left, y_top, x_right, y_bottom)
                coordinates of the box.
        Returns:
            x_coordinate (float): the chosen x-coordinate.
            alignment (str): whether to align left or right.
        """
        x_left = float(box_coordinate[0])
        # If the x-coordinate is greater than 5/6 of the image width,
        # then we align text to the right of the box. This is
        # chosen by heuristics.
        if x_left > (self.output.width * 5) // 6:
            return float(box_coordinate[2]), "right"
        return x_left, "left"

    def _align_y_top(
        self, box_coordinate: torch.Tensor, num_text: int, textbox_width: float
    ) -> int:
        """
        Calculate the number of text labels to plot on top of the box
        without going out of frames.
        Args:
            box_coordinate (array-like): shape (4,). The (x_left, y_top, x_right, y_bottom)
                coordinates of the box.
            num_text (int): the number of text labels to plot.
            textbox_width (float): the width of the box wrapped around text label.
        Returns:
            (int): a count in [0, num_text].
        """
        dist_to_top = float(box_coordinate[1])
        num_text_top = int(dist_to_top // textbox_width)
        # A box whose top edge lies above the frame yields a negative quotient;
        # clamp it so callers never receive a negative slice index.
        return max(0, min(num_text, num_text_top))

    def _align_y_bottom(
        self, box_coordinate: torch.Tensor, num_text: int, textbox_width: float
    ) -> int:
        """
        Calculate the number of text labels to plot at the bottom of the box
        without going out of frames.
        Args:
            box_coordinate (array-like): shape (4,). The (x_left, y_top, x_right, y_bottom)
                coordinates of the box.
            num_text (int): the number of text labels to plot.
            textbox_width (float): the width of the box wrapped around text label.
        Returns:
            (int): a count in [0, num_text].
        """
        dist_to_bottom = self.output.height - float(box_coordinate[3])
        num_text_bottom = int(dist_to_bottom // textbox_width)
        # A box whose bottom edge lies below the frame yields a negative
        # quotient; clamp it to zero.
        return max(0, min(num_text, num_text_bottom))
class VideoVisualizer:
    """
    Draw predicted or ground-truth class labels, and optionally bounding
    boxes, onto single frames or whole clips.
    """

    def __init__(
        self,
        num_classes: int,
        class_names: Dict,
        top_k: int = 1,
        colormap: str = "rainbow",
        thres: float = 0.7,
        lower_thres: float = 0.3,
        common_class_names: Optional[List[str]] = None,
        mode: str = "top-k",
    ) -> None:
        """
        Args:
            num_classes (int): total number of classes.
            class_names (dict): Dict mapping classID to name.
            top_k (int): number of top predicted classes to plot.
            colormap (str): the colormap to choose color for class labels from.
                See https://matplotlib.org/tutorials/colors/colormaps.html
            thres (float): threshold for picking predicted classes to visualize.
            lower_thres (Optional[float]): If `common_class_names` is given,
                this `lower_thres` will be applied to uncommon classes and
                `thres` will be applied to classes in `common_class_names`.
            common_class_names (Optional[list of str]): list of common class names
                to apply `thres`. Class names not included in `common_class_names` will
                have `lower_thres` as a threshold. If None, all classes will have
                `thres` as a threshold. This is helpful for model trained on
                highly imbalanced dataset.
            mode (str): Supported modes are {"top-k", "thres"}.
                This is used for choosing predictions for visualization.
        """
        assert mode in ["top-k", "thres"], "Mode {} is not supported.".format(mode)
        self.mode = mode
        self.num_classes = num_classes
        self.class_names = class_names
        self.top_k = top_k
        self.thres = thres
        self.lower_thres = lower_thres
        if mode == "thres":
            # Replaces the scalar `self.thres` with a per-class threshold tensor.
            self._get_thres_array(common_class_names=common_class_names)
        self.color_map = plt.get_cmap(colormap)

    def _get_color(self, class_id: int) -> Tuple[float, ...]:
        """
        Get color for a class id.
        Args:
            class_id (int): class id.
        Returns:
            The first three components of the colormap entry looked up at
            `class_id / num_classes`.
        """
        return self.color_map(class_id / self.num_classes)[:3]

    def draw_one_frame(
        self,
        frame: Union[torch.Tensor, np.ndarray],
        preds: Union[torch.Tensor, List[float]],
        bboxes: Optional[torch.Tensor] = None,
        alpha: float = 0.5,
        text_alpha: float = 0.7,
        ground_truth: bool = False,
    ) -> Optional[np.ndarray]:
        """
        Draw labels and bounding boxes for one image. By default, predicted
        labels are drawn in the top left corner of the image or corresponding
        bounding boxes. For ground truth labels (setting True for ground_truth flag),
        labels will be drawn in the bottom left corner.
        Args:
            frame (array-like): a tensor or numpy array of shape (H, W, C),
                where H and W correspond to
                the height and width of the image respectively. C is the number of
                color channels. The image is required to be in RGB format since that
                is a requirement of the Matplotlib library. The image is also expected
                to be in the range [0, 255].
            preds (tensor or list): If ground_truth is False, provide a float tensor of
                shape (num_boxes, num_classes) that contains all of the confidence
                scores of the model. For recognition task, input shape can be (num_classes,).
                To plot true label (ground_truth is True), preds contains class ids
                of the shape (num_boxes, true_class_ids) or (true_class_ids,).
            bboxes (Optional[tensor]): shape (num_boxes, 4) that contains the coordinates
                of the bounding boxes.
            alpha (Optional[float]): transparency level of the bounding boxes.
            text_alpha (Optional[float]): transparency level of the box wrapped around
                text labels.
            ground_truth (bool): whether the provided bounding boxes are ground-truth.
        Returns:
            An image with bounding box annotations and corresponding bbox
            labels plotted on it, or None if `preds` is neither a tensor nor a list.
        """
        if isinstance(preds, torch.Tensor):
            if preds.ndim == 1:
                preds = preds.unsqueeze(0)
            n_instances = preds.shape[0]
        elif isinstance(preds, list):
            # A flat list of integer class ids describes a single instance,
            # mirroring the 1-D tensor case above.
            if preds and all(isinstance(p, (int, np.integer)) for p in preds):
                preds = [preds]
            n_instances = len(preds)
        else:
            # Unsupported type of prediction input.
            return None

        if ground_truth:
            # Convert tensors to plain Python numbers so that class ids can be
            # used as dictionary keys when looking up class names.
            top_classes = preds.tolist() if isinstance(preds, torch.Tensor) else preds
            top_scores = [None] * n_instances
        elif self.mode == "top-k":
            # Never request more entries than there are classes.
            k = min(self.top_k, preds.shape[-1])
            top_scores, top_classes = torch.topk(preds, k=k)
            top_scores, top_classes = top_scores.tolist(), top_classes.tolist()
        elif self.mode == "thres":
            top_scores, top_classes = [], []
            for pred in preds:
                mask = pred >= self.thres
                top_scores.append(pred[mask].tolist())
                top_classes.append(
                    torch.squeeze(torch.nonzero(mask), dim=-1).tolist()
                )

        # Create labels for the selected classes, with their scores if available.
        text_labels = [
            _create_text_labels(
                top_classes[i],
                top_scores[i],
                self.class_names,
                ground_truth=ground_truth,
            )
            for i in range(n_instances)
        ]

        frame_visualizer = ImgVisualizer(frame, meta=None)
        font_size = min(max(np.sqrt(frame.shape[0] * frame.shape[1]) // 25, 5), 9)
        top_corner = not ground_truth

        if bboxes is not None:
            assert len(preds) == len(
                bboxes
            ), "Encounter {} predictions and {} bounding boxes".format(
                len(preds), len(bboxes)
            )
            box_color = "r" if ground_truth else "g"
            line_style = "--" if ground_truth else "-."
            for i, box in enumerate(bboxes):
                colors = [self._get_color(pred) for pred in top_classes[i]]
                frame_visualizer.draw_box(
                    box,
                    alpha=alpha,
                    edge_color=box_color,
                    line_style=line_style,
                )
                frame_visualizer.draw_multiple_text(
                    text_labels[i],
                    box,
                    top_corner=top_corner,
                    font_size=font_size,
                    box_facecolors=colors,
                    alpha=text_alpha,
                )
        else:
            # No boxes: place the labels of the first instance relative to a
            # pseudo-box spanning the frame with a 5-pixel vertical margin.
            colors = [self._get_color(pred) for pred in top_classes[0]]
            frame_visualizer.draw_multiple_text(
                text_labels[0],
                torch.Tensor([0, 5, frame.shape[1], frame.shape[0] - 5]),
                top_corner=top_corner,
                font_size=font_size,
                box_facecolors=colors,
                alpha=text_alpha,
            )
        return frame_visualizer.output.get_image()

    def draw_clip_range(
        self,
        frames: Union[torch.Tensor, np.ndarray],
        preds: Union[torch.Tensor, List[float]],
        bboxes: Optional[torch.Tensor] = None,
        text_alpha: float = 0.5,
        ground_truth: bool = False,
        keyframe_idx: Optional[int] = None,
        draw_range: Optional[List[int]] = None,
        repeat_frame: int = 1,
    ) -> List[np.ndarray]:
        """
        Draw predicted labels or ground truth classes to clip.
        Draw bounding boxes to clip if bboxes is provided. Boxes will gradually
        fade in and out the clip, centered around the keyframe,
        within the provided `draw_range`.
        Args:
            frames (array-like): video data in the shape (T, H, W, C).
            preds (tensor): a tensor of shape (num_boxes, num_classes) that
                contains all of the confidence scores of the model. For recognition
                task or for ground_truth labels, input shape can be (num_classes,).
            bboxes (Optional[tensor]): shape (num_boxes, 4) that contains the coordinates
                of the bounding boxes.
            text_alpha (float): transparency label of the box wrapped around text labels.
            ground_truth (bool): whether the provided bounding boxes are ground-truth.
            keyframe_idx (int): the index of keyframe in the clip. Defaults to
                the middle frame of `frames`.
            draw_range (Optional[list[ints]): only draw frames in range
                [start_idx, end_idx] inclusively in the clip. If None, draw on
                the entire clip. The list is not modified.
            repeat_frame (int): repeat each frame in draw_range for `repeat_frame`
                time for slow-motion effect.
        Returns:
            A list of frames with bounding box annotations and corresponding
            bbox labels plotted on them; frames outside `draw_range` are
            returned undrawn.
        """
        # Work on local copies so the caller's `draw_range` list is untouched.
        if draw_range is None:
            start, end = 0, len(frames) - 1
        else:
            start, end = max(0, draw_range[0]), draw_range[1]

        if keyframe_idx is None:
            keyframe_idx = len(frames) // 2

        drawn = self.draw_clip(
            frames[start : end + 1],
            preds,
            bboxes=bboxes,
            text_alpha=text_alpha,
            ground_truth=ground_truth,
            # Re-express the keyframe relative to the drawn sub-clip.
            keyframe_idx=keyframe_idx - start,
            repeat_frame=repeat_frame,
        )
        return list(frames[:start]) + drawn + list(frames[end + 1 :])

    def draw_clip(
        self,
        frames: Union[torch.Tensor, np.ndarray],
        preds: Union[torch.Tensor, List[float]],
        bboxes: Optional[torch.Tensor] = None,
        text_alpha: float = 0.5,
        ground_truth: bool = False,
        keyframe_idx: Optional[int] = None,
        repeat_frame: int = 1,
    ) -> List[np.ndarray]:
        """
        Draw predicted labels or ground truth classes to clip. Draw bounding boxes to clip
        if bboxes is provided. Boxes will gradually fade in and out the clip, centered
        around the keyframe (the clip's central frame by default).
        Args:
            frames (array-like): video data in the shape (T, H, W, C).
            preds (tensor): a tensor of shape (num_boxes, num_classes) that contains
                all of the confidence scores of the model. For recognition task or for
                ground_truth labels, input shape can be (num_classes,).
            bboxes (Optional[tensor]): shape (num_boxes, 4) that contains the coordinates
                of the bounding boxes.
            text_alpha (float): transparency label of the box wrapped around text labels.
            ground_truth (bool): whether the provided bounding boxes are ground-truth.
            keyframe_idx (int): the index of keyframe in the clip. Values outside
                the clip are clamped to its first/last position.
            repeat_frame (int): repeat each frame for `repeat_frame`
                time for slow-motion effect.
        Returns:
            A list of frames with bounding box annotations and corresponding
            bbox labels plotted on them.
        """
        assert repeat_frame >= 1, "`repeat_frame` must be a positive integer."
        frames, adjusted = self._adjust_frames_type(frames)
        num_frames = len(frames)
        # Each frame index appears `repeat_frame` times in a row.
        frame_order = [
            idx for idx in range(num_frames) for _ in range(repeat_frame)
        ]
        alpha_ls = self._fade_alphas(num_frames, repeat_frame, keyframe_idx)

        img_ls = []
        for alpha, frame in zip(alpha_ls, frames[frame_order]):
            draw_img = self.draw_one_frame(
                frame,
                preds,
                bboxes,
                alpha=alpha,
                text_alpha=text_alpha,
                ground_truth=ground_truth,
            )
            if adjusted:
                # Input was floating point: scale the drawn frame back down.
                draw_img = draw_img.astype("float32") / 255
            img_ls.append(draw_img)
        return img_ls

    def _fade_alphas(
        self, num_frames: int, repeat_frame: int, keyframe_idx: Optional[int]
    ) -> np.ndarray:
        """
        Build the per-frame box transparency: a linear ramp from 0 up to 1
        followed by a linear ramp from 1 down to 0.
        Args:
            num_frames (int): number of distinct frames in the clip.
            repeat_frame (int): how many times each frame is repeated.
            keyframe_idx (Optional[int]): frame index where the second ramp
                starts. If None, the ramps split the sequence in half.
        Returns:
            1-D array of length `num_frames * repeat_frame`.
        """
        total = num_frames * repeat_frame
        if keyframe_idx is None:
            peak = total // 2
        else:
            peak = int((keyframe_idx / num_frames) * total)
            # The keyframe may lie outside the drawn clip (e.g. when only a
            # sub-range is drawn); np.linspace rejects a negative `num`, so
            # keep the peak inside [0, total].
            peak = min(max(peak, 0), total)
        return np.concatenate(
            [
                np.linspace(0, 1, num=peak),
                np.linspace(1, 0, num=total - peak),
            ]
        )

    def _adjust_frames_type(
        self, frames: Union[torch.Tensor, np.ndarray]
    ) -> Tuple[np.ndarray, bool]:
        """
        Modify video data to have dtype of uint8 and values range in [0, 255].
        Args:
            frames (array-like): 4D array of shape (T, H, W, C).
        Returns:
            frames (np.ndarray): the frames as a numpy array; floating-point
                input is multiplied by 255 and cast to uint8, other dtypes
                are returned as converted by `np.array`.
            adjusted (bool): whether the original frames needed adjusting.
        """
        assert (
            frames is not None and len(frames) != 0
        ), "Frames does not contain any values"
        # np.array copies, so the scaling below never touches the caller's data.
        frames = np.array(frames)
        assert frames.ndim == 4, "Frames must have 4 dimensions"
        adjusted = False
        if np.issubdtype(frames.dtype, np.floating):
            frames = (frames * 255).astype(np.uint8)
            adjusted = True
        return frames, adjusted

    def _get_thres_array(self, common_class_names: Optional[List[str]] = None) -> None:
        """
        Compute a thresholds array for all classes based on `self.thres` and
        `self.lower_thres`, and store it in `self.thres` as a tensor.
        Args:
            common_class_names (Optional[list of str]): a list of common class names.
                Classes named here get `self.thres`; all others get
                `self.lower_thres`. If None, every class gets `self.thres`.
        """
        if common_class_names is not None:
            common_classes = set(common_class_names)
            common_class_ids = [
                key for key, name in self.class_names.items() if name in common_classes
            ]
        else:
            common_class_ids = list(range(self.num_classes))

        thres_array = np.full(shape=(self.num_classes,), fill_value=self.lower_thres)
        thres_array[common_class_ids] = self.thres
        self.thres = torch.from_numpy(thres_array)