# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates.

import csv
import logging
import numpy as np
from typing import Any, Callable, Dict, List, Optional, Union
import av
import torch
from torch.utils.data.dataset import Dataset

from detectron2.utils.file_io import PathManager

from ..utils import maybe_prepend_base_path
from .frame_selector import FrameSelector, FrameTsList

FrameList = List[av.frame.Frame]  # pyre-ignore[16]
FrameTransform = Callable[[torch.Tensor], torch.Tensor]

# One module-level logger; this is the same logger object that
# ``logging.getLogger(__name__)`` returns inside any function of this module.
logger = logging.getLogger(__name__)


def list_keyframes(video_fpath: str, video_stream_idx: int = 0) -> FrameTsList:
    """
    Traverses all keyframes of a video file. Returns a list of keyframe
    timestamps. Timestamps are counts in timebase units.

    The traversal repeatedly seeks forward (to keyframes only) past the last
    seen packet timestamp, demuxes one packet and records its timestamp if the
    packet is a keyframe.

    Args:
       video_fpath (str): Video file path
       video_stream_idx (int): Video stream index (default: 0)
    Returns:
       List[int]: list of keyframe timestamps (timestamp is a count in timebase
           units). When an AV seek error occurs (e.g. the video length is
           exceeded), or the demuxer yields no packet / a packet without a
           timestamp, the timestamps collected so far are returned. An empty
           list is returned on OS errors, when too many backward seeks are
           observed, or when the container cannot be opened.
    """
    try:
        with PathManager.open(video_fpath, "rb") as io:
            container = av.open(io, mode="r")
            try:
                stream = container.streams.video[video_stream_idx]
                keyframes = []
                pts = -1
                # Note: even though we request forward seeks for keyframes, sometimes
                # a keyframe in backwards direction is returned. We introduce tolerance
                # as a max count of ignored backward seeks
                tolerance_backward_seeks = 2
                while True:
                    try:
                        container.seek(pts + 1, backward=False, any_frame=False, stream=stream)
                    except av.AVError as e:
                        # the exception occurs when the video length is exceeded,
                        # we then return whatever data we've already collected
                        logger.debug(
                            f"List keyframes: Error seeking video file {video_fpath}, "
                            f"video stream {video_stream_idx}, pts {pts + 1}, AV error: {e}"
                        )
                        return keyframes
                    except OSError as e:
                        logger.warning(
                            f"List keyframes: Error seeking video file {video_fpath}, "
                            f"video stream {video_stream_idx}, pts {pts + 1}, OS error: {e}"
                        )
                        return []
                    # A bare ``next`` would leak StopIteration to the caller (it is
                    # not an OSError/RuntimeError) once the demuxer is exhausted;
                    # treat exhaustion as the end of the video instead.
                    packet = next(container.demux(video=video_stream_idx), None)
                    if packet is None:
                        return keyframes
                    if packet.pts is not None and packet.pts <= pts:
                        logger.warning(
                            f"Video file {video_fpath}, stream {video_stream_idx}: "
                            f"bad seek for packet {pts + 1} (got packet {packet.pts}), "
                            f"tolerance {tolerance_backward_seeks}."
                        )
                        tolerance_backward_seeks -= 1
                        if tolerance_backward_seeks == 0:
                            return []
                        # step past the offending timestamp and retry the seek
                        pts += 1
                        continue
                    # a forward seek succeeded: reset the backward-seek budget
                    tolerance_backward_seeks = 2
                    pts = packet.pts
                    if pts is None:
                        # a packet without timestamp ends the traversal
                        # (presumably the demuxer's end-of-stream flush packet)
                        return keyframes
                    if packet.is_keyframe:
                        keyframes.append(pts)
            finally:
                # release decoder/demuxer resources before PathManager closes
                # the underlying file object
                container.close()
    except OSError as e:
        logger.warning(
            f"List keyframes: Error opening video file container {video_fpath}, "
            f"OS error: {e}"
        )
    except RuntimeError as e:
        logger.warning(
            f"List keyframes: Error opening video file container {video_fpath}, "
            f"Runtime error: {e}"
        )
    return []


def read_keyframes(
    video_fpath: str, keyframes: FrameTsList, video_stream_idx: int = 0
) -> FrameList:  # pyre-ignore[11]
    """
    Reads keyframe data from a video file.

    For every requested timestamp the container is seeked (to a keyframe) on
    the selected video stream and the first frame decoded from that same
    stream is collected. Reading stops at the first seek/decode error.

    Args:
        video_fpath (str): Video file path
        keyframes (List[int]): List of keyframe timestamps (as counts in
            timebase units to be used in container seek operations)
        video_stream_idx (int): Video stream index (default: 0)
    Returns:
        List[Frame]: list of frames that correspond to the specified timestamps.
            If an error occurs while seeking or decoding, the frames read
            before the error are returned; if the container cannot be opened,
            an empty list is returned.
    """
    try:
        with PathManager.open(video_fpath, "rb") as io:
            container = av.open(io)
            try:
                stream = container.streams.video[video_stream_idx]
                frames = []
                for pts in keyframes:
                    try:
                        container.seek(pts, any_frame=False, stream=stream)
                        # decode from the same video stream that was seeked
                        # (previously hard-coded to stream 0)
                        frame = next(container.decode(video=video_stream_idx))
                    except av.AVError as e:
                        logger.warning(
                            f"Read keyframes: Error seeking video file {video_fpath}, "
                            f"video stream {video_stream_idx}, pts {pts}, AV error: {e}"
                        )
                        break
                    except OSError as e:
                        logger.warning(
                            f"Read keyframes: Error seeking video file {video_fpath}, "
                            f"video stream {video_stream_idx}, pts {pts}, OS error: {e}"
                        )
                        break
                    except StopIteration:
                        logger.warning(
                            f"Read keyframes: Error decoding frame from {video_fpath}, "
                            f"video stream {video_stream_idx}, pts {pts}"
                        )
                        break
                    frames.append(frame)
                return frames
            finally:
                # single close point for success, early-stop and unexpected errors
                container.close()
    except OSError as e:
        logger.warning(
            f"Read keyframes: Error opening video file container {video_fpath}, OS error: {e}"
        )
    except RuntimeError as e:
        logger.warning(
            f"Read keyframes: Error opening video file container {video_fpath}, Runtime error: {e}"
        )
    return []


def video_list_from_file(video_list_fpath: str, base_path: Optional[str] = None) -> List[str]:
    """
    Create a list of paths to video files from a text file.

    Each line of the file yields one entry (surrounding whitespace stripped),
    passed through ``maybe_prepend_base_path`` with ``base_path``.

    Args:
        video_list_fpath (str): path to a plain text file with the list of videos
        base_path (str): base path for entries from the video list (default: None)
    Returns:
        List[str]: one path per line of the input file, in file order
    """
    with PathManager.open(video_list_fpath, "r") as io:
        return [maybe_prepend_base_path(base_path, str(line.strip())) for line in io]


def read_keyframe_helper_data(fpath: str) -> Dict[int, List[int]]:
    """
    Read keyframe data from a file in CSV format: the header should contain
    "video_id" and "keyframes" fields. Value specifications are:
      video_id: int
      keyframes: list(int)
    Example of contents:
      video_id,keyframes
      2,"[1,11,21,31,41,51,61,71,81]"

    Args:
        fpath (str): File containing keyframe data

    Return:
        video_id_to_keyframes (dict: int -> list(int)): for a given video ID it
          contains a list of keyframes for that video. Any read/parse error
          (including a duplicated video ID) is logged as a warning and the
          entries parsed up to that point are returned.
    """
    video_id_to_keyframes = {}
    try:
        with PathManager.open(fpath, "r") as io:
            csv_reader = csv.reader(io)
            header = next(csv_reader)
            video_id_idx = header.index("video_id")
            keyframes_idx = header.index("keyframes")
            for row in csv_reader:
                video_id = int(row[video_id_idx])
                # explicit check rather than ``assert``: asserts are stripped under -O,
                # which would silently overwrite duplicated entries
                if video_id in video_id_to_keyframes:
                    raise ValueError(
                        f"Duplicate keyframes entry for video {video_id} in {fpath}"
                    )
                keyframes_str = row[keyframes_idx]
                # strip the surrounding "[" and "]"; "[]" (or shorter) means no keyframes
                video_id_to_keyframes[video_id] = (
                    [int(v) for v in keyframes_str[1:-1].split(",")]
                    if len(keyframes_str) > 2
                    else []
                )
    except Exception as e:
        logger.warning(f"Error reading keyframe helper data from {fpath}: {e}")
    return video_id_to_keyframes


class VideoKeyframeDataset(Dataset):
    """
    Dataset that provides keyframes for a set of videos.
    """

    # Returned when no keyframes/frames are available for a video.
    # NOTE(review): this single tensor is shared by every such sample.
    _EMPTY_FRAMES = torch.empty((0, 3, 1, 1))

    def __init__(
        self,
        video_list: List[str],
        category_list: Union[str, List[str], None] = None,
        frame_selector: Optional[FrameSelector] = None,
        transform: Optional[FrameTransform] = None,
        keyframe_helper_fpath: Optional[str] = None,
    ):
        """
        Dataset constructor

        Args:
            video_list (List[str]): list of paths to video files
            category_list (Union[str, List[str], None]): list of animal categories for
                each video file. If it is a string, or None, this applies to all videos
            frame_selector (Callable: KeyFrameList -> KeyFrameList):
                selects keyframes to process, keyframes are given by
                packet timestamps in timebase counts. If None, all keyframes
                are selected (default: None)
            transform (Callable: torch.Tensor -> torch.Tensor):
                transforms a batch of RGB images (tensors of size [B, 3, H, W]),
                returns a tensor of the same size. If None, no transform is
                applied (default: None)
            keyframe_helper_fpath (Optional[str]): path to a CSV file with
                precomputed keyframes (see ``read_keyframe_helper_data``).
                Videos found there skip keyframe listing. If None, keyframes
                are always listed from the video files (default: None)
        """
        if isinstance(category_list, list):
            self.category_list = category_list
        else:
            self.category_list = [category_list] * len(video_list)
        assert len(video_list) == len(
            self.category_list
        ), "length of video and category lists must be equal"
        self.video_list = video_list
        self.frame_selector = frame_selector
        self.transform = transform
        self.keyframe_helper_data = (
            read_keyframe_helper_data(keyframe_helper_fpath)
            if keyframe_helper_fpath is not None
            else None
        )

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """
        Gets selected keyframes from a given video

        Args:
            idx (int): video index in the video list file
        Returns:
            A dictionary containing two keys:
                images (torch.Tensor): float tensor of size [N, 3, H, W] with
                    channels in BGR order, or of size defined by the transform
                    that contains keyframes data
                categories (List[str]): categories of the frames (a single-element
                    list with the video's category; empty if no frames were read)
        """
        categories = [self.category_list[idx]]
        fpath = self.video_list[idx]
        # NOTE(review): helper data is keyed by the CSV "video_id" but looked up by
        # the video list index; this assumes the two coincide.
        keyframes = (
            list_keyframes(fpath)
            if self.keyframe_helper_data is None or idx not in self.keyframe_helper_data
            else self.keyframe_helper_data[idx]
        )
        transform = self.transform
        frame_selector = self.frame_selector
        if not keyframes:
            return {"images": self._EMPTY_FRAMES, "categories": []}
        if frame_selector is not None:
            keyframes = frame_selector(keyframes)
        frames = read_keyframes(fpath, keyframes)
        if not frames:
            return {"images": self._EMPTY_FRAMES, "categories": []}
        frames = np.stack([frame.to_rgb().to_ndarray() for frame in frames])
        frames = torch.as_tensor(frames, device=torch.device("cpu"))
        frames = frames[..., [2, 1, 0]]  # RGB -> BGR
        frames = frames.permute(0, 3, 1, 2).float()  # NHWC -> NCHW
        if transform is not None:
            frames = transform(frames)
        return {"images": frames, "categories": categories}

    def __len__(self):
        """Number of videos in the dataset."""
        return len(self.video_list)