# zcxu-eric's picture
# add app
# 8aa9c9a
# raw
# history blame
# 6.1 kB
# *************************************************************************
# This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-
# difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-
# ytedance Inc..
# *************************************************************************
# Copyright 2022 ByteDance and/or its affiliates.
#
# Copyright (2022) PV3D Authors
#
# ByteDance, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from ByteDance or
# its affiliates is strictly prohibited.
import av, gc
import torch
import warnings
import numpy as np
# Module-wide count of VideoReader._read_video calls (shared by all readers);
# VideoReader._occasional_gc uses it to call gc.collect() periodically.
_CALLED_TIMES = 0
# Number of _read_video calls between forced garbage collections.
_GC_COLLECTION_INTERVAL = 20
# Silence PyAV/FFmpeg log messages below ERROR severity.
av.logging.set_level(av.logging.ERROR)
class VideoReader:
    """
    Simple wrapper around PyAV (pythonic bindings for the FFmpeg libraries)
    that exposes a few helpers for reading frames from a video.

    Only the first video stream of the container is used. The reader owns the
    opened container; call ``close()`` (or use the reader as a context manager)
    to release it.

    Acknowledgement: code adapted from Bruno Korbar.
    """

    def __init__(self, video, num_frames=float("inf"), decode_lossy=False, audio_resample_rate=None, bi_frame=False):
        """
        Arguments:
            video: path or byte stream of the video to be loaded (anything
                ``av.open`` accepts).
            num_frames: number of consecutive frames decoded by ``sample()``
                and ``read()``; the default (infinity) decodes to the end.
            decode_lossy: if True, enable multi-threaded decoding
                (``thread_type = 'AUTO'``), which may drop frames.
            audio_resample_rate: if given, an ``av.AudioResampler`` with this
                rate is stored on ``self.resampler`` (otherwise None).
            bi_frame: if True, ``sample()`` keeps only two frames of the
                sampled clip and reports their relative times.
        """
        self.container = av.open(video)
        self.num_frames = num_frames
        self.bi_frame = bi_frame
        self.resampler = None
        if audio_resample_rate is not None:
            self.resampler = av.AudioResampler(rate=audio_resample_rate)
        if self.container.streams.video:
            if decode_lossy:
                # Multi-threaded video decoding: faster, but frames may be dropped.
                warnings.warn('VideoReader| thread_type==AUTO can yield potential frame dropping!', RuntimeWarning)
                self.container.streams.video[0].thread_type = 'AUTO'
            self.video_stream = self.container.streams.video[0]
        else:
            self.video_stream = None
        # 0.0 when there is no video stream or no usable rate metadata.
        self.fps = self._get_video_frame_rate()

    def close(self):
        """Close the underlying container; calling it again is a no-op."""
        if self.container is not None:
            self.container.close()
            self.container = None
            self.video_stream = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False

    def seek(self, pts, backward=True, any_frame=False):
        """Seek the container to ``pts``.

        ``pts`` is interpreted in the video stream's time base (or in
        ``av.time_base`` units when there is no video stream).
        """
        self.container.seek(pts, any_frame=any_frame, backward=backward, stream=self.video_stream)

    def _occasional_gc(self):
        """Run ``gc.collect()`` once every ``_GC_COLLECTION_INTERVAL`` calls.

        PyAV creates many reference cycles, so memory is reclaimed manually
        from time to time. The counter is module-global, i.e. shared by all
        readers.
        """
        global _CALLED_TIMES
        _CALLED_TIMES += 1
        if _CALLED_TIMES % _GC_COLLECTION_INTERVAL == _GC_COLLECTION_INTERVAL - 1:
            gc.collect()

    def _duration_pts(self):
        """Return the clip duration in ``av.time_base`` units, or 0 if unknown.

        ``container.duration`` is preferred. Some containers leave it unset; in
        that case the video stream's own duration is converted instead.
        """
        if self.container.duration is not None:
            return self.container.duration
        stream = self.video_stream
        if stream is not None and stream.duration is not None:
            return int(stream.duration * stream.time_base * av.time_base)
        return 0

    def _read_video(self, offset, num_frames=None):
        """Decode consecutive frames starting at a relative position in the clip.

        Arguments:
            offset: start position as a fraction of the clip duration (0 = start).
            num_frames: maximum number of frames to return; defaults to
                ``self.num_frames``. Values below 1 still yield one frame.
        Returns:
            list of decoded frames whose presentation time is >= the start time.
        """
        self._occasional_gc()
        if num_frames is None:
            num_frames = self.num_frames
        pts = self._duration_pts() * offset
        time_ = pts / float(av.time_base)
        # Seeking lands on a keyframe at or before the target, so frames that
        # precede the requested time are decoded and skipped below.
        self.container.seek(int(pts))
        video_frames = []
        for frame in self._iter_frames():
            # Frames without a pts cannot be placed in time; skip them.
            if frame.pts is None or frame.pts * frame.time_base < time_:
                continue
            video_frames.append(frame)
            if len(video_frames) >= num_frames:
                break
        return video_frames

    def _iter_frames(self):
        """Yield every decoded frame of the video stream from the current position."""
        for packet in self.container.demux(self.video_stream):
            yield from packet.decode()

    def _compute_video_stats(self):
        """Return ``(start_pts, pts_step, num_frames)`` for the video stream.

        ``num_frames`` comes from the stream header. When the header reports 0,
        it is estimated as ``fps * duration_in_seconds`` and may be a float.
        ``start_pts`` is the pts of the first decoded frame (0 if none).
        ``pts_step`` is the pts difference between the first two frames (512 if
        fewer than two frames decode).
        Returns ``(0, 0, 0)`` when there is no video stream or the reader is
        closed.
        """
        if self.video_stream is None or self.container is None:
            return 0, 0, 0
        stream = self.video_stream
        num_of_frames = stream.frames
        if num_of_frames == 0:
            if stream.duration is not None:
                duration_s = float(stream.duration * stream.time_base)
            else:
                duration_s = self._duration_pts() / float(av.time_base)
            num_of_frames = self.fps * duration_s
        self.seek(0, backward=False)
        start_pts, pts_step = 0, 512
        for count, frame in enumerate(self.container.decode(video=0), start=1):
            if count == 1:
                start_pts = frame.pts
            else:
                pts_step = frame.pts - start_pts
                break
        return start_pts, pts_step, num_of_frames

    def _get_video_frame_rate(self):
        """Return the video frame rate in frames/second, or 0.0 if unavailable.

        Uses the stream's ``guessed_rate`` and falls back to ``average_rate``.
        """
        stream = self.video_stream
        if stream is None:
            return 0.0
        rate = stream.guessed_rate or stream.average_rate
        return float(rate) if rate else 0.0

    @staticmethod
    def _frames_to_rgb(frames):
        """Stack decoded frames into one uint8 RGB array (T, H, W, 3)."""
        return np.array([np.uint8(f.to_rgb().to_ndarray()) for f in frames])

    def sample(self, debug=False):
        """Sample a random clip of up to ``self.num_frames`` consecutive frames.

        Arguments:
            debug: unused; kept for API compatibility.
        Returns:
            dict with "frames" (uint8 RGB array of the clip) and "frame_idx"
            ([start frame index]). When ``bi_frame`` is set:
              - "frames" holds only two frames of the clip, at positions drawn
                from Beta(2, 1) and Beta(1, 2) and then sorted;
              - "real_t" is a float32 tensor of those positions divided by
                ``num_frames - 1``;
              - "frame_idx" holds their absolute frame indices.
        Raises:
            RuntimeError: no video stream (or reader closed), or the frame
                count cannot be determined.
            ValueError: ``bi_frame`` is set but ``num_frames`` is not a finite
                number >= 2.
        """
        if self.container is None or self.video_stream is None:
            raise RuntimeError('video stream not found')
        if self.bi_frame and not (2 <= self.num_frames < float("inf")):
            raise ValueError('bi_frame sampling requires a finite num_frames >= 2')
        _, _, total_num_frames = self._compute_video_stats()
        if total_num_frames <= 0:
            raise RuntimeError('could not determine the number of video frames')
        # torch.randint needs an int bound, but the estimated frame count may be a float.
        high = int(max(1, total_num_frames - self.num_frames - 1))
        offset = torch.randint(high, [1]).item()
        video_frames = self._frames_to_rgb(self._read_video(offset / total_num_frames))
        sample = {"frames": video_frames, "frame_idx": [offset]}
        if self.bi_frame:
            # One position skewed towards the clip end, one towards its start.
            late = float(np.random.beta(2, 1, size=1)[0])
            early = float(np.random.beta(1, 2, size=1)[0])
            lo, hi = sorted([int(late * self.num_frames), int(early * self.num_frames)])
            span = self.num_frames - 1
            # NOTE(review): indexing assumes the clip decoded num_frames frames;
            # a clip cut short by the end of the video can raise IndexError here.
            sample["frames"] = np.array([video_frames[lo], video_frames[hi]])
            sample["real_t"] = torch.tensor([lo / span, hi / span], dtype=torch.float32)
            sample["frame_idx"] = [offset + lo, offset + hi]
        return sample

    def read_frames(self, frame_indices):
        """Return the first and last frame of an inclusive frame-index range.

        Arguments:
            frame_indices: ``(start, end)`` absolute frame indices, e.g. the
                "frame_idx" produced by ``sample()`` with ``bi_frame``.
        Returns:
            uint8 RGB array holding the frames at ``start`` and ``end``.

        Unlike earlier versions, this does not overwrite ``self.num_frames``.
        """
        start, end = frame_indices[0], frame_indices[1]
        # Inclusive range: decode end - start + 1 frames so the last one is `end`.
        frames = self._read_video(start / self.get_num_frames(), num_frames=end - start + 1)
        return self._frames_to_rgb([frames[0], frames[-1]])

    def read(self):
        """Decode up to ``self.num_frames`` frames from the start of the video."""
        return self._frames_to_rgb(self._read_video(0))

    def get_num_frames(self):
        """Return the (possibly estimated) total number of video frames."""
        return self._compute_video_stats()[2]