Spaces:
Running
on
Zero
Running
on
Zero
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
# All rights reserved. | |
# | |
# This source code is licensed under the license found in the | |
# LICENSE file in the root directory of this source tree. | |
import random | |
from collections import defaultdict | |
from typing import Dict, List, Optional, Union | |
from mmcv.transforms import BaseTransform | |
from mmdet.registry import TRANSFORMS | |
class BaseFrameSample(BaseTransform): | |
"""Directly get the key frame, no reference frames. | |
Args: | |
collect_video_keys (list[str]): The keys of video info to be | |
collected. | |
""" | |
def __init__(self, | |
collect_video_keys: List[str] = ['video_id', 'video_length']): | |
self.collect_video_keys = collect_video_keys | |
def prepare_data(self, video_infos: dict, | |
sampled_inds: List[int]) -> Dict[str, List]: | |
"""Prepare data for the subsequent pipeline. | |
Args: | |
video_infos (dict): The whole video information. | |
sampled_inds (list[int]): The sampled frame indices. | |
Returns: | |
dict: The processed data information. | |
""" | |
frames_anns = video_infos['images'] | |
final_data_info = defaultdict(list) | |
# for data in frames_anns: | |
for index in sampled_inds: | |
data = frames_anns[index] | |
# copy the info in video-level into img-level | |
for key in self.collect_video_keys: | |
if key == 'video_length': | |
data['ori_video_length'] = video_infos[key] | |
data['video_length'] = len(sampled_inds) | |
else: | |
data[key] = video_infos[key] | |
# Collate data_list (list of dict to dict of list) | |
for key, value in data.items(): | |
final_data_info[key].append(value) | |
return final_data_info | |
def transform(self, video_infos: dict) -> Optional[Dict[str, List]]: | |
"""Transform the video information. | |
Args: | |
video_infos (dict): The whole video information. | |
Returns: | |
dict: The data information of the key frames. | |
""" | |
if 'key_frame_id' in video_infos: | |
key_frame_id = video_infos['key_frame_id'] | |
assert isinstance(video_infos['key_frame_id'], int) | |
else: | |
key_frame_id = random.sample( | |
list(range(video_infos['video_length'])), 1)[0] | |
results = self.prepare_data(video_infos, [key_frame_id]) | |
return results | |
def __repr__(self) -> str: | |
repr_str = self.__class__.__name__ | |
repr_str += f'(collect_video_keys={self.collect_video_keys})' | |
return repr_str | |
class UniformRefFrameSample(BaseFrameSample): | |
"""Uniformly sample reference frames. | |
Args: | |
num_ref_imgs (int): Number of reference frames to be sampled. | |
frame_range (int | list[int]): Range of frames to be sampled around | |
key frame. If int, the range is [-frame_range, frame_range]. | |
Defaults to 10. | |
filter_key_img (bool): Whether to filter the key frame when | |
sampling reference frames. Defaults to True. | |
collect_video_keys (list[str]): The keys of video info to be | |
collected. | |
""" | |
def __init__(self, | |
num_ref_imgs: int = 1, | |
frame_range: Union[int, List[int]] = 10, | |
filter_key_img: bool = True, | |
collect_video_keys: List[str] = ['video_id', 'video_length']): | |
self.num_ref_imgs = num_ref_imgs | |
self.filter_key_img = filter_key_img | |
if isinstance(frame_range, int): | |
assert frame_range >= 0, 'frame_range can not be a negative value.' | |
frame_range = [-frame_range, frame_range] | |
elif isinstance(frame_range, list): | |
assert len(frame_range) == 2, 'The length must be 2.' | |
assert frame_range[0] <= 0 and frame_range[1] >= 0 | |
for i in frame_range: | |
assert isinstance(i, int), 'Each element must be int.' | |
else: | |
raise TypeError('The type of frame_range must be int or list.') | |
self.frame_range = frame_range | |
super().__init__(collect_video_keys=collect_video_keys) | |
def sampling_frames(self, video_length: int, key_frame_id: int): | |
"""Sampling frames. | |
Args: | |
video_length (int): The length of the video. | |
key_frame_id (int): The key frame id. | |
Returns: | |
list[int]: The sampled frame indices. | |
""" | |
if video_length > 1: | |
left = max(0, key_frame_id + self.frame_range[0]) | |
right = min(key_frame_id + self.frame_range[1], video_length - 1) | |
frame_ids = list(range(0, video_length)) | |
valid_ids = frame_ids[left:right + 1] | |
if self.filter_key_img and key_frame_id in valid_ids: | |
valid_ids.remove(key_frame_id) | |
assert len( | |
valid_ids | |
) > 0, 'After filtering key frame, there are no valid frames' | |
if len(valid_ids) < self.num_ref_imgs: | |
valid_ids = valid_ids * self.num_ref_imgs | |
ref_frame_ids = random.sample(valid_ids, self.num_ref_imgs) | |
else: | |
ref_frame_ids = [key_frame_id] * self.num_ref_imgs | |
sampled_frames_ids = [key_frame_id] + ref_frame_ids | |
sampled_frames_ids = sorted(sampled_frames_ids) | |
key_frames_ind = sampled_frames_ids.index(key_frame_id) | |
key_frame_flags = [False] * len(sampled_frames_ids) | |
key_frame_flags[key_frames_ind] = True | |
return sampled_frames_ids, key_frame_flags | |
def transform(self, video_infos: dict) -> Optional[Dict[str, List]]: | |
"""Transform the video information. | |
Args: | |
video_infos (dict): The whole video information. | |
Returns: | |
dict: The data information of the sampled frames. | |
""" | |
if 'key_frame_id' in video_infos: | |
key_frame_id = video_infos['key_frame_id'] | |
assert isinstance(video_infos['key_frame_id'], int) | |
else: | |
key_frame_id = random.sample( | |
list(range(video_infos['video_length'])), 1)[0] | |
(sampled_frames_ids, key_frame_flags) = self.sampling_frames( | |
video_infos['video_length'], key_frame_id=key_frame_id) | |
results = self.prepare_data(video_infos, sampled_frames_ids) | |
results['key_frame_flags'] = key_frame_flags | |
return results | |
def __repr__(self) -> str: | |
repr_str = self.__class__.__name__ | |
repr_str += f'(num_ref_imgs={self.num_ref_imgs}, ' | |
repr_str += f'frame_range={self.frame_range}, ' | |
repr_str += f'filter_key_img={self.filter_key_img}, ' | |
repr_str += f'collect_video_keys={self.collect_video_keys})' | |
return repr_str | |