# Spaces: Running on Zero
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import List, Tuple | |
import torch | |
import torch.nn.functional as F | |
from mmengine.structures import InstanceData | |
from torch import Tensor | |
from mmdet.registry import MODELS | |
from mmdet.structures import TrackDataSample | |
from mmdet.structures.bbox import bbox_overlaps | |
from .base_tracker import BaseTracker | |
class QuasiDenseTracker(BaseTracker):
    """Tracker for Quasi-Dense Tracking.

    The tracker keeps two memories: ``tracks`` (a dict keyed by integer track
    id) and ``backdrops`` (a list of per-frame dicts holding detections that
    were not assigned to any track). Both are used as matching candidates
    for the detections of the next frame.

    Args:
        init_score_thr (float): The cls_score threshold to
            initialize a new tracklet. Defaults to 0.8.
        obj_score_thr (float): The cls_score threshold to
            update a tracked tracklet. Defaults to 0.5.
        match_score_thr (float): The match threshold. Defaults to 0.5.
        memo_tracklet_frames (int): The most frames in a tracklet memory.
            Defaults to 10.
        memo_backdrop_frames (int): The most frames in the backdrops.
            Defaults to 1.
        memo_momentum (float): The momentum value for embeds updating.
            Must lie in [0, 1]. Defaults to 0.8.
        nms_conf_thr (float): The nms threshold for confidence.
            Defaults to 0.5.
        nms_backdrop_iou_thr (float): The nms threshold for backdrop IoU.
            Defaults to 0.3.
        nms_class_iou_thr (float): The nms threshold for class IoU.
            Defaults to 0.7.
        with_cats (bool): Whether to track with the same category.
            Defaults to True.
        match_metric (str): The match metric, one of ``'bisoftmax'``,
            ``'softmax'`` or ``'cosine'``. Defaults to 'bisoftmax'.
        **kwargs: Forwarded unchanged to the base tracker constructor.
    """

    def __init__(self,
                 init_score_thr: float = 0.8,
                 obj_score_thr: float = 0.5,
                 match_score_thr: float = 0.5,
                 memo_tracklet_frames: int = 10,
                 memo_backdrop_frames: int = 1,
                 memo_momentum: float = 0.8,
                 nms_conf_thr: float = 0.5,
                 nms_backdrop_iou_thr: float = 0.3,
                 nms_class_iou_thr: float = 0.7,
                 with_cats: bool = True,
                 match_metric: str = 'bisoftmax',
                 **kwargs):
        super().__init__(**kwargs)
        assert 0 <= memo_momentum <= 1.0
        assert memo_tracklet_frames >= 0
        assert memo_backdrop_frames >= 0
        self.init_score_thr = init_score_thr
        self.obj_score_thr = obj_score_thr
        self.match_score_thr = match_score_thr
        self.memo_tracklet_frames = memo_tracklet_frames
        self.memo_backdrop_frames = memo_backdrop_frames
        self.memo_momentum = memo_momentum

        self.nms_conf_thr = nms_conf_thr
        self.nms_backdrop_iou_thr = nms_backdrop_iou_thr
        self.nms_class_iou_thr = nms_class_iou_thr
        self.with_cats = with_cats
        assert match_metric in ['bisoftmax', 'softmax', 'cosine']
        self.match_metric = match_metric

        # Next track id to hand out; always kept as a plain Python int.
        self.num_tracks = 0
        # track id (int) -> dict(bbox, embed, label, score, last_frame,
        # velocity, acc_frame)
        self.tracks = dict()
        # Newest-first list of dict(bboxes, embeds, labels).
        self.backdrops = []

    def reset(self):
        """Reset the buffer of the tracker."""
        self.num_tracks = 0
        self.tracks = dict()
        self.backdrops = []

    def update(self, ids: Tensor, bboxes: Tensor, embeds: Tensor,
               labels: Tensor, scores: Tensor, frame_id: int) -> None:
        """Write the current frame's results into the tracker memory.

        Entries with ``ids > -1`` update (or create) tracks, entries with
        ``ids == -1`` become backdrop candidates, and any other value is
        ignored.

        Args:
            ids (Tensor): of shape (N, ). Track id assigned to each box.
            bboxes (Tensor): of shape (N, D), one box per row.
            embeds (Tensor): of shape (N, C), one embedding per row.
            labels (Tensor): of shape (N, ).
            scores (Tensor): of shape (N, ).
            frame_id (int): The id of current frame, 0-index.
        """
        tracklet_inds = ids > -1

        for track_id, bbox, embed, label, score in zip(
                ids[tracklet_inds], bboxes[tracklet_inds],
                embeds[tracklet_inds], labels[tracklet_inds],
                scores[tracklet_inds]):
            track_id = int(track_id)
            track = self.tracks.get(track_id)
            if track is None:
                # First time this id is seen: start a fresh track.
                self.tracks[track_id] = dict(
                    bbox=bbox,
                    embed=embed,
                    label=label,
                    score=score,
                    last_frame=frame_id,
                    velocity=torch.zeros_like(bbox),
                    acc_frame=0)
                continue

            # Per-frame box displacement since the track was last seen.
            # NOTE(review): assumes frame_id differs from last_frame;
            # otherwise this divides by zero.
            velocity = (bbox - track['bbox']) / (
                frame_id - track['last_frame'])
            track['bbox'] = bbox
            # Blend the stored embedding with the new one.
            track['embed'] = (
                1 - self.memo_momentum
            ) * track['embed'] + self.memo_momentum * embed
            track['last_frame'] = frame_id
            track['label'] = label
            track['score'] = score
            # Running mean of the velocity over all updates of this track.
            track['velocity'] = (
                track['velocity'] * track['acc_frame'] + velocity) / (
                    track['acc_frame'] + 1)
            track['acc_frame'] += 1

        # backdrop update according to IoU: drop a backdrop candidate when
        # it overlaps any box that precedes it in `bboxes`.
        backdrop_inds = torch.nonzero(ids == -1, as_tuple=False).squeeze(1)
        ious = bbox_overlaps(bboxes[backdrop_inds], bboxes)
        for i, ind in enumerate(backdrop_inds):
            if (ious[i, :ind] > self.nms_backdrop_iou_thr).any():
                backdrop_inds[i] = -1
        backdrop_inds = backdrop_inds[backdrop_inds > -1]

        # Newest backdrops go to the front so the oldest is popped first.
        self.backdrops.insert(
            0,
            dict(
                bboxes=bboxes[backdrop_inds],
                embeds=embeds[backdrop_inds],
                labels=labels[backdrop_inds]))

        # Forget tracks that have not been updated for too many frames.
        # Collect first: the dict must not change size while iterating.
        stale_ids = [
            track_id for track_id, track in self.tracks.items()
            if frame_id - track['last_frame'] >= self.memo_tracklet_frames
        ]
        for stale_id in stale_ids:
            self.tracks.pop(stale_id)

        if len(self.backdrops) > self.memo_backdrop_frames:
            self.backdrops.pop()

    @property
    def memo(self) -> Tuple[Tensor, ...]:
        """Get tracks memory.

        This is a property because ``track`` reads it as ``self.memo``.

        Returns:
            tuple[Tensor, ...]: ``(bboxes, labels, embeds, ids, velocities)``
            with the stored tracks first, followed by every stored backdrop.
            Backdrop rows carry id ``-1`` and zero velocity.
        """
        memo_embeds = []
        memo_ids = []
        memo_bboxes = []
        memo_labels = []
        # velocity of tracks
        memo_vs = []

        # get tracks
        for track_id, track in self.tracks.items():
            memo_bboxes.append(track['bbox'][None, :])
            memo_embeds.append(track['embed'][None, :])
            memo_ids.append(track_id)
            memo_labels.append(track['label'].view(1, 1))
            memo_vs.append(track['velocity'][None, :])
        memo_ids = torch.tensor(memo_ids, dtype=torch.long).view(1, -1)

        # get backdrops
        for backdrop in self.backdrops:
            backdrop_ids = torch.full((1, backdrop['embeds'].size(0)),
                                      -1,
                                      dtype=torch.long)
            backdrop_vs = torch.zeros_like(backdrop['bboxes'])
            memo_bboxes.append(backdrop['bboxes'])
            memo_embeds.append(backdrop['embeds'])
            memo_ids = torch.cat([memo_ids, backdrop_ids], dim=1)
            memo_labels.append(backdrop['labels'][:, None])
            memo_vs.append(backdrop_vs)

        memo_bboxes = torch.cat(memo_bboxes, dim=0)
        memo_embeds = torch.cat(memo_embeds, dim=0)
        memo_labels = torch.cat(memo_labels, dim=0).squeeze(1)
        memo_vs = torch.cat(memo_vs, dim=0)
        return (memo_bboxes, memo_labels, memo_embeds, memo_ids.squeeze(0),
                memo_vs)

    def track(self,
              model: torch.nn.Module,
              img: torch.Tensor,
              feats: List[torch.Tensor],
              data_sample: TrackDataSample,
              rescale: bool = True,
              **kwargs) -> InstanceData:
        """Tracking forward function.

        Args:
            model (nn.Module): MOT model. Its ``track_head.predict`` is
                called to extract one embedding per detection.
            img (Tensor): Input image tensor. Not read by this method.
            feats (list[Tensor]): Multi level feature maps of `img`.
            data_sample (:obj:`TrackDataSample`): The data sample.
                ``pred_instances`` (bboxes, labels, scores) and ``metainfo``
                are read from it.
            rescale (bool, optional): If True, the boxes are multiplied by
                ``metainfo['scale_factor']`` before being passed to the
                track head. Defaults to True.

        Returns:
            :obj:`InstanceData`: Tracking results of the input images.
            Each InstanceData usually contains ``bboxes``, ``labels``,
            ``scores`` and ``instances_id``. When there are no detections,
            a clone of ``pred_instances`` with all-zero ``instances_id`` is
            returned.
        """
        metainfo = data_sample.metainfo
        bboxes = data_sample.pred_instances.bboxes
        labels = data_sample.pred_instances.labels
        scores = data_sample.pred_instances.scores

        frame_id = metainfo.get('frame_id', -1)
        # create pred_track_instances
        pred_track_instances = InstanceData()

        # return zero bboxes if there is no track targets
        if bboxes.shape[0] == 0:
            pred_track_instances = data_sample.pred_instances.clone()
            pred_track_instances.instances_id = torch.zeros_like(labels)
            return pred_track_instances

        # get track feats
        rescaled_bboxes = bboxes.clone()
        if rescale:
            # presumably scale_factor is (w_scale, h_scale), repeated to
            # cover both box corners -- TODO confirm
            scale_factor = rescaled_bboxes.new_tensor(
                metainfo['scale_factor']).repeat((1, 2))
            rescaled_bboxes = rescaled_bboxes * scale_factor
        track_feats = model.track_head.predict(feats, [rescaled_bboxes])

        # sort according to the object_score
        _, inds = scores.sort(descending=True)
        bboxes = bboxes[inds]
        scores = scores[inds]
        labels = labels[inds]
        embeds = track_feats[inds, :]

        # duplicate removal for potential backdrops and cross classes:
        # a box is dropped when it overlaps any higher-scoring box. Boxes
        # scoring below obj_score_thr use the backdrop IoU threshold.
        num_dets = bboxes.size(0)
        keep = bboxes.new_ones((num_dets, ), dtype=torch.bool)
        ious = bbox_overlaps(bboxes, bboxes)
        for i in range(1, num_dets):
            if scores[i] < self.obj_score_thr:
                thr = self.nms_backdrop_iou_thr
            else:
                thr = self.nms_class_iou_thr
            if (ious[i, :i] > thr).any():
                keep[i] = False
        bboxes = bboxes[keep]
        scores = scores[keep]
        labels = labels[keep]
        embeds = embeds[keep, :]

        # init ids container; -1 means "not assigned yet"
        ids = torch.full((bboxes.size(0), ), -1, dtype=torch.long)

        # match if buffer is not empty (`empty` comes from the base class)
        if bboxes.size(0) > 0 and not self.empty:
            (memo_bboxes, memo_labels, memo_embeds, memo_ids,
             memo_vs) = self.memo

            if self.match_metric == 'bisoftmax':
                similarity = torch.mm(embeds, memo_embeds.t())
                d2t_scores = similarity.softmax(dim=1)
                t2d_scores = similarity.softmax(dim=0)
                match_scores = (d2t_scores + t2d_scores) / 2
            elif self.match_metric == 'softmax':
                similarity = torch.mm(embeds, memo_embeds.t())
                match_scores = similarity.softmax(dim=1)
            elif self.match_metric == 'cosine':
                match_scores = torch.mm(
                    F.normalize(embeds, p=2, dim=1),
                    F.normalize(memo_embeds, p=2, dim=1).t())
            else:
                raise NotImplementedError

            # track with the same category
            if self.with_cats:
                cat_same = labels.view(-1, 1) == memo_labels.view(1, -1)
                match_scores *= cat_same.float().to(match_scores.device)

            # track according to match_scores
            for i in range(bboxes.size(0)):
                conf, memo_ind = torch.max(match_scores[i, :], dim=0)
                matched_id = memo_ids[memo_ind]
                if conf <= self.match_score_thr or matched_id <= -1:
                    continue
                # keep bboxes with high object score
                # and remove background bboxes
                if scores[i] > self.obj_score_thr:
                    ids[i] = matched_id
                    # A memory entry can be claimed by one detection only.
                    match_scores[:i, memo_ind] = 0
                    match_scores[i + 1:, memo_ind] = 0
                elif conf > self.nms_conf_thr:
                    # -2 keeps this box out of both the tracks (ids > -1)
                    # and the backdrops / new tracks (ids == -1).
                    ids[i] = -2

        # initialize new tracks
        new_inds = (ids == -1) & (scores > self.init_score_thr).cpu()
        # Convert to int so that `num_tracks` never turns into a tensor.
        num_news = int(new_inds.sum())
        ids[new_inds] = torch.arange(
            self.num_tracks, self.num_tracks + num_news, dtype=torch.long)
        self.num_tracks += num_news

        self.update(ids, bboxes, embeds, labels, scores, frame_id)
        tracklet_inds = ids > -1

        # update pred_track_instances
        pred_track_instances.bboxes = bboxes[tracklet_inds]
        pred_track_instances.labels = labels[tracklet_inds]
        pred_track_instances.scores = scores[tracklet_inds]
        pred_track_instances.instances_id = ids[tracklet_inds]

        return pred_track_instances