# Hugging Face Space file view (status: Running) — commit 8aa9c9a, 6,097 bytes.
# *************************************************************************
# This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-
# difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-
# ytedance Inc..
# *************************************************************************
# Copyright 2022 ByteDance and/or its affiliates.
#
# Copyright (2022) PV3D Authors
#
# ByteDance, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from ByteDance or
# its affiliates is strictly prohibited.
import av, gc
import torch
import warnings
import numpy as np
# Bookkeeping for VideoReader._occasional_gc: a running call counter and the
# number of calls between forced garbage collections.
_CALLED_TIMES, _GC_COLLECTION_INTERVAL = 0, 20

# Raise the PyAV/FFmpeg log threshold to ERROR so warnings are not printed.
av.logging.set_level(av.logging.ERROR)
class VideoReader:
    """
    Simple wrapper around PyAV that exposes a few useful functions for
    dealing with video reading. PyAV is a pythonic binding for the ffmpeg libraries.
    Acknowledgement: Codes are borrowed from Bruno Korbar
    """

    def __init__(self, video, num_frames=float("inf"), decode_lossy=False, audio_resample_rate=None, bi_frame=False):
        """
        Arguments:
            video (str or file-like): path or byte stream of the video; passed
                directly to ``av.open``.
            num_frames (int or float): number of frames collected per read;
                ``float("inf")`` (default) keeps reading until the stream ends.
            decode_lossy (bool): if True, switch the first video stream to
                ``thread_type = 'AUTO'`` (multi-threaded decoding) and emit a
                RuntimeWarning, since this may drop frames.
            audio_resample_rate (int or None): if given, an ``av.AudioResampler``
                with this rate is stored on ``self.resampler`` (it is not used
                elsewhere in this class).
            bi_frame (bool): if True, ``sample`` returns only two frames of the
                clip together with their normalized positions.
        """
        self.container = av.open(video)
        self.num_frames = num_frames
        self.bi_frame = bi_frame
        self.resampler = None
        if audio_resample_rate is not None:
            self.resampler = av.AudioResampler(rate=audio_resample_rate)
        if self.container.streams.video:
            # enable multi-threaded video decoding
            if decode_lossy:
                warnings.warn('VideoReader| thread_type==AUTO can yield potential frame dropping!', RuntimeWarning)
                self.container.streams.video[0].thread_type = 'AUTO'
            self.video_stream = self.container.streams.video[0]
        else:
            self.video_stream = None
        # 0.0 when there is no video stream or the rate is unknown, so that
        # construction does not fail on audio-only / odd containers.
        self.fps = self._get_video_frame_rate()

    def seek(self, pts, backward=True, any_frame=False):
        """Seek the container to ``pts``, using the video stream as the seek reference."""
        self.container.seek(pts, any_frame=any_frame, backward=backward, stream=self.video_stream)

    def _occasional_gc(self):
        # there are a lot of reference cycles in PyAV, so need to manually call
        # the garbage collector from time to time
        global _CALLED_TIMES
        _CALLED_TIMES += 1
        if _CALLED_TIMES % _GC_COLLECTION_INTERVAL == _GC_COLLECTION_INTERVAL - 1:
            gc.collect()

    def _read_video(self, offset, num_frames=None):
        """Decode consecutive frames starting at a relative position in the video.

        Arguments:
            offset (float): start position as a fraction of ``container.duration``
                (0.0 = beginning).
            num_frames (int or float, optional): maximum number of frames to
                collect; defaults to ``self.num_frames``.

        Returns:
            list of decoded PyAV frames whose timestamp is at or after the start
            position, at most ``num_frames`` long.
        """
        if num_frames is None:
            num_frames = self.num_frames
        self._occasional_gc()
        pts = self.container.duration * offset
        time_ = pts / float(av.time_base)
        # container-level seek (no stream argument), so pts is in av.time_base units
        self.container.seek(int(pts))
        video_frames = []
        for frame in self._iter_frames():
            # frames without a timestamp cannot be placed relative to the start
            if frame.pts is None:
                continue
            # the seek may land before the target, so skip frames decoded
            # ahead of the requested start time
            if frame.pts * frame.time_base >= time_:
                video_frames.append(frame)
                if len(video_frames) >= num_frames:
                    break
        return video_frames

    def _iter_frames(self):
        """Yield decoded frames from the packets demuxed for the video stream."""
        for packet in self.container.demux(self.video_stream):
            yield from packet.decode()

    def _compute_video_stats(self):
        """Return ``(start_pts, pts_step, num_frames)`` for the first video stream.

        ``start_pts`` is the pts of the first decoded frame (None if no frame
        decodes) and ``pts_step`` the pts difference to the second frame (512
        if fewer than two frames decode). ``num_frames`` is the stream's frame
        count or, when the container does not record it, an integer estimate
        from fps * duration. Returns ``(None, None, 0)`` without a video stream.

        Note: this seeks the container to the beginning and decodes frames.
        """
        if self.video_stream is None or self.container is None:
            return None, None, 0
        stream = self.container.streams.video[0]
        num_of_frames = stream.frames
        if num_of_frames == 0:
            if stream.duration is not None:
                duration_s = float(stream.duration * stream.time_base)
            elif self.container.duration is not None:
                duration_s = self.container.duration / float(av.time_base)
            else:
                duration_s = 0.0
            # integer count: callers use it as torch.randint's upper bound
            num_of_frames = int(self.fps * duration_s)
        self.seek(0, backward=False)
        start_pts = None
        time_base = 512
        for count, frame in enumerate(self.container.decode(video=0), start=1):
            if count == 1:
                start_pts = frame.pts
            else:
                time_base = frame.pts - start_pts
                break
        return start_pts, time_base, num_of_frames

    def _get_video_frame_rate(self):
        """Return the stream's guessed frame rate as a float (0.0 if unknown or no video stream)."""
        if self.video_stream is None:
            return 0.0
        rate = self.container.streams.video[0].guessed_rate
        return float(rate) if rate is not None else 0.0

    @staticmethod
    def _frame_to_rgb(frame):
        """Convert a decoded PyAV frame to an RGB ``uint8`` numpy array."""
        return np.uint8(frame.to_rgb().to_ndarray())

    def sample(self, debug=False):
        """Randomly pick a clip of ``self.num_frames`` frames and decode it.

        Arguments:
            debug (bool): unused; kept for API compatibility.

        Returns:
            dict with
              - "frames": uint8 array of the decoded RGB frames (only two frames
                when ``bi_frame`` is set),
              - "frame_idx": ``[offset]`` (clip start index), or the absolute
                indices of the two frames when ``bi_frame`` is set,
              - "real_t" (``bi_frame`` only): float32 tensor with the two frames'
                positions divided by ``num_frames - 1``.

        Raises:
            RuntimeError: if the container has no video stream.
        """
        if self.container is None or self.video_stream is None:
            raise RuntimeError('video stream not found')
        sample = dict()
        _, _, total_num_frames = self._compute_video_stats()
        # upper bound keeps offset + num_frames inside the video; int() because
        # torch.randint rejects float bounds
        high = int(max(1, total_num_frames - self.num_frames - 1))
        offset = torch.randint(high, [1]).item()
        video_frames = self._read_video(offset / total_num_frames)
        video_frames = np.array([self._frame_to_rgb(f) for f in video_frames])
        sample["frames"] = video_frames
        sample["frame_idx"] = [offset]
        if self.bi_frame:
            # one position drawn towards the end (Beta(2, 1)) and one towards
            # the start (Beta(1, 2)) of the clip; scalar draws avoid int() on arrays
            first, last = sorted([
                int(np.random.beta(2, 1) * self.num_frames),
                int(np.random.beta(1, 2) * self.num_frames),
            ])
            # avoid division by zero for single-frame clips
            denom = max(self.num_frames - 1, 1)
            sample["frames"] = np.array([video_frames[first], video_frames[last]])
            sample["real_t"] = torch.tensor([first / denom, last / denom], dtype=torch.float32)
            sample["frame_idx"] = [offset + first, offset + last]
        return sample

    def read_frames(self, frame_indices):
        """Decode the first and last frame of an inclusive frame-index range.

        Arguments:
            frame_indices (sequence of int): ``[start, end]`` absolute frame
                indices, e.g. the ``"frame_idx"`` that ``sample`` returns with
                ``bi_frame`` enabled.

        Returns:
            uint8 array stacking the RGB frames at ``start`` and ``end``.
        """
        start, end = frame_indices[0], frame_indices[1]
        # read end - start + 1 frames so the last one is ``end`` itself, and pass
        # the count explicitly instead of overwriting self.num_frames (sample()
        # relies on it)
        video_frames = self._read_video(start / self.get_num_frames(), num_frames=end - start + 1)
        return np.array([
            self._frame_to_rgb(video_frames[0]),
            self._frame_to_rgb(video_frames[-1]),
        ])

    def read(self):
        """Decode frames from the beginning of the video (up to ``self.num_frames``) as a uint8 array."""
        video_frames = self._read_video(0)
        return np.array([self._frame_to_rgb(f) for f in video_frames])

    def get_num_frames(self):
        """Return the (possibly estimated) number of frames in the video stream."""
        _, _, total_num_frames = self._compute_video_stats()
        return total_num_frames