# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from abc import ABCMeta, abstractmethod
from typing import List, Optional, Tuple
import torch
import torch.nn.functional as F
from addict import Dict
class BaseTracker(metaclass=ABCMeta):
"""Base tracker model.
Args:
momentums (dict[str:float], optional): Momentums to update the buffers.
The `str` indicates the name of the buffer while the `float`
indicates the momentum. Defaults to None.
num_frames_retain (int, optional). If a track is disappeared more than
`num_frames_retain` frames, it will be deleted in the memo.
Defaults to 10.
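
    Example:
        A concrete subclass only needs to implement ``track``; a
        hypothetical, runnable sketch is provided at the bottom of this
        file.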
"""
def __init__(self,
momentums: Optional[dict] = None,
num_frames_retain: int = 10) -> None:
super().__init__()
if momentums is not None:
assert isinstance(momentums, dict), 'momentums must be a dict'
self.momentums = momentums
self.num_frames_retain = num_frames_retain
self.reset()
def reset(self) -> None:
"""Reset the buffer of the tracker."""
self.num_tracks = 0
self.tracks = dict()
@property
def empty(self) -> bool:
"""Whether the buffer is empty or not."""
        return not self.tracks
@property
    def ids(self) -> List[int]:
"""All ids in the tracker."""
return list(self.tracks.keys())
@property
def with_reid(self) -> bool:
"""bool: whether the framework has a reid model"""
return hasattr(self, 'reid') and self.reid is not None
def update(self, **kwargs) -> None:
"""Update the tracker.
Args:
kwargs (dict[str: Tensor | int]): The `str` indicates the
name of the input variable. `ids` and `frame_ids` are
obligatory in the keys.
"""
        # Drop None-valued entries; the remaining keys define the per-track
        # buffers (`memo_items`) and must stay consistent across calls.
        memo_items = [k for k, v in kwargs.items() if v is not None]
        rm_items = [k for k in kwargs.keys() if k not in memo_items]
        for item in rm_items:
            kwargs.pop(item)
if not hasattr(self, 'memo_items'):
self.memo_items = memo_items
else:
assert memo_items == self.memo_items
assert 'ids' in memo_items
num_objs = len(kwargs['ids'])
id_indice = memo_items.index('ids')
assert 'frame_ids' in memo_items
        # `frame_ids` may be passed as a single int for the whole frame;
        # broadcast it to one entry per object so it zips with the others.
        frame_id = int(kwargs['frame_ids'])
        if isinstance(kwargs['frame_ids'], int):
            kwargs['frame_ids'] = torch.tensor([kwargs['frame_ids']] *
                                               num_objs)
for k, v in kwargs.items():
if len(v) != num_objs:
                raise ValueError('all kwargs values must have the same length')
for obj in zip(*kwargs.values()):
id = int(obj[id_indice])
if id in self.tracks:
self.update_track(id, obj)
else:
self.init_track(id, obj)
self.pop_invalid_tracks(frame_id)
def pop_invalid_tracks(self, frame_id: int) -> None:
"""Pop out invalid tracks."""
invalid_ids = []
for k, v in self.tracks.items():
if frame_id - v['frame_ids'][-1] >= self.num_frames_retain:
invalid_ids.append(k)
for invalid_id in invalid_ids:
self.tracks.pop(invalid_id)
    def update_track(self, id: int, obj: Tuple[torch.Tensor, ...]) -> None:
"""Update a track."""
for k, v in zip(self.memo_items, obj):
v = v[None]
if self.momentums is not None and k in self.momentums:
m = self.momentums[k]
self.tracks[id][k] = (1 - m) * self.tracks[id][k] + m * v
else:
self.tracks[id][k].append(v)
    def init_track(self, id: int, obj: Tuple[torch.Tensor, ...]) -> None:
"""Initialize a track."""
self.tracks[id] = Dict()
for k, v in zip(self.memo_items, obj):
v = v[None]
if self.momentums is not None and k in self.momentums:
self.tracks[id][k] = v
else:
self.tracks[id][k] = [v]
@property
def memo(self) -> dict:
"""Return all buffers in the tracker."""
outs = Dict()
for k in self.memo_items:
outs[k] = []
for id, objs in self.tracks.items():
for k, v in objs.items():
if k not in outs:
continue
                # Momentum buffers are single tensors; list buffers keep a
                # history, of which only the latest entry is exported.
                if self.momentums is None or k not in self.momentums:
                    v = v[-1]
outs[k].append(v)
for k, v in outs.items():
outs[k] = torch.cat(v, dim=0)
return outs
def get(self,
item: str,
ids: Optional[list] = None,
num_samples: Optional[int] = None,
behavior: Optional[str] = None) -> torch.Tensor:
"""Get the buffer of a specific item.
Args:
item (str): The demanded item.
ids (list[int], optional): The demanded ids. Defaults to None.
num_samples (int, optional): Number of samples to calculate the
results. Defaults to None.
behavior (str, optional): Behavior to calculate the results.
Options are `mean` | None. Defaults to None.
Returns:
Tensor: The results of the demanded item.
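
        Example:
            Assuming a list buffer named ``embeds`` exists (a hypothetical
            name, e.g. populated by a reid model), the mean of each track's
            last three embeddings could be queried as
            ``self.get('embeds', num_samples=3, behavior='mean')``.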
"""
if ids is None:
ids = self.ids
outs = []
for id in ids:
out = self.tracks[id][item]
if isinstance(out, list):
if num_samples is not None:
out = out[-num_samples:]
out = torch.cat(out, dim=0)
if behavior == 'mean':
out = out.mean(dim=0, keepdim=True)
elif behavior is None:
out = out[None]
else:
raise NotImplementedError()
else:
out = out[-1]
outs.append(out)
return torch.cat(outs, dim=0)
@abstractmethod
def track(self, *args, **kwargs):
"""Tracking forward function."""
pass
def crop_imgs(self,
img: torch.Tensor,
meta_info: dict,
bboxes: torch.Tensor,
rescale: bool = False) -> torch.Tensor:
"""Crop the images according to some bounding boxes. Typically for re-
identification sub-module.
Args:
img (Tensor): of shape (T, C, H, W) encoding input image.
Typically these should be mean centered and std scaled.
meta_info (dict): image information dict where each dict
has: 'img_shape', 'scale_factor', 'flip', and may also contain
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
bboxes (Tensor): of shape (N, 4) or (N, 5).
rescale (bool, optional): If True, the bounding boxes should be
rescaled to fit the scale of the image. Defaults to False.
        Returns:
            Tensor: The cropped (and optionally resized) image patches,
                concatenated along the first dimension.
"""
h, w = meta_info['img_shape']
img = img[:, :, :h, :w]
if rescale:
factor_x, factor_y = meta_info['scale_factor']
bboxes[:, :4] *= torch.tensor(
[factor_x, factor_y, factor_x, factor_y]).to(bboxes.device)
bboxes[:, 0] = torch.clamp(bboxes[:, 0], min=0, max=w - 1)
bboxes[:, 1] = torch.clamp(bboxes[:, 1], min=0, max=h - 1)
bboxes[:, 2] = torch.clamp(bboxes[:, 2], min=1, max=w)
bboxes[:, 3] = torch.clamp(bboxes[:, 3], min=1, max=h)
crop_imgs = []
for bbox in bboxes:
x1, y1, x2, y2 = map(int, bbox)
            # Guarantee a crop of at least one pixel in each dimension.
            if x2 <= x1:
                x2 = x1 + 1
            if y2 <= y1:
                y2 = y1 + 1
crop_img = img[:, :, y1:y2, x1:x2]
if self.reid.get('img_scale', False):
crop_img = F.interpolate(
crop_img,
size=self.reid['img_scale'],
mode='bilinear',
align_corners=False)
crop_imgs.append(crop_img)
if len(crop_imgs) > 0:
return torch.cat(crop_imgs, dim=0)
elif self.reid.get('img_scale', False):
_h, _w = self.reid['img_scale']
return img.new_zeros((0, 3, _h, _w))
else:
return img.new_zeros((0, 3, h, w))
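

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only; `DummyTracker` and the fake
# detections below are hypothetical, not part of the upstream API). A real
# subclass would implement `track` with an actual assignment strategy.
# ---------------------------------------------------------------------------
if __name__ == '__main__':

    class DummyTracker(BaseTracker):
        """Toy tracker that stores whatever ids it is given."""

        def track(self, ids: torch.Tensor, bboxes: torch.Tensor,
                  frame_id: int) -> None:
            self.update(ids=ids, bboxes=bboxes, frame_ids=frame_id)

    tracker = DummyTracker(num_frames_retain=2)
    # Frame 0: two objects enter the scene.
    tracker.track(
        ids=torch.tensor([0, 1]),
        bboxes=torch.tensor([[0., 0., 10., 10.], [5., 5., 20., 20.]]),
        frame_id=0)
    # Frame 1: only object 0 is re-detected; object 1 is retained for now.
    tracker.track(
        ids=torch.tensor([0]),
        bboxes=torch.tensor([[1., 1., 11., 11.]]),
        frame_id=1)
    print(tracker.ids)                   # [0, 1]
    print(tracker.memo['bboxes'].shape)  # torch.Size([2, 4])
    # Frame 3: object 1 has been unseen for >= num_frames_retain frames,
    # so `pop_invalid_tracks` drops it during this update.
    tracker.track(
        ids=torch.tensor([0]),
        bboxes=torch.tensor([[2., 2., 12., 12.]]),
        frame_id=3)
    print(tracker.ids)                   # [0]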