pllava-34b-demo / dataset /video_utils.py
cathyxl
added
f239efc
raw
history blame
7.83 kB
"""
Modified from https://github.com/m-bain/frozen-in-time/blob/22a91d78405ec6032fdf521ae1ff5573358e632f/base/base_dataset.py
"""
import random
import io
import os
import av
import cv2
import decord
import imageio
from decord import VideoReader
# from dataloader import KVReader
import torch
import numpy as np
import math
# import tensorflow as tf
decord.bridge.set_bridge("torch")
import logging
logger = logging.getLogger(__name__)
def pts_to_secs(pts: int, time_base: float, start_pts: int) -> float:
    """
    Translate a presentation timestamp (pts) into seconds.

    The offset ``pts - start_pts`` is truncated to a whole number of ticks
    and then multiplied by ``time_base``. An infinite ``pts`` is passed
    through unchanged as ``math.inf``.

    Args:
        pts: timestamp to convert, expressed in ticks.
        time_base: factor applied to a tick count to obtain seconds.
        start_pts: timestamp that is subtracted from ``pts`` before scaling.

    Returns:
        time_in_seconds (float): The corresponding time in seconds.

    https://github.com/facebookresearch/pytorchvideo/blob/main/pytorchvideo/data/utils.py#L54-L64
    """
    if pts != math.inf:
        elapsed_ticks = int(pts - start_pts)
        return elapsed_ticks * time_base
    return math.inf
def get_pyav_video_duration(video_reader):
    """
    Return the duration, in seconds, of the first video stream of a PyAV container.

    The stream-level ``duration`` / ``start_time`` / ``time_base`` are used
    when the stream reports a duration. Some containers do not store a
    per-stream duration (PyAV then exposes ``None``); in that case the
    container-level ``duration`` is used instead.

    Args:
        video_reader: an opened PyAV container (as returned by ``av.open``).

    Returns:
        float: the duration in seconds.

    Raises:
        ValueError: if neither the stream nor the container reports a duration.
    """
    video_stream = video_reader.streams.video[0]

    if video_stream.duration is not None:
        # `start_time` may be unset as well; treat a missing value as 0 so the
        # subtraction inside `pts_to_secs` does not fail on `None`.
        start_pts = video_stream.start_time or 0
        video_duration = pts_to_secs(
            video_stream.duration,
            video_stream.time_base,
            start_pts,
        )
        return float(video_duration)

    # Fallback: the container duration is expressed in `av.time_base` units.
    container_duration = getattr(video_reader, "duration", None)
    if container_duration is None:
        raise ValueError(
            "Unable to determine the video duration: neither the video stream "
            "nor the container reports one."
        )
    return float(container_duration) / av.time_base
def get_frame_indices_by_fps():
    """Placeholder with no implementation yet; takes no arguments and returns ``None``."""
    return None
def get_frame_indices(num_frames, vlen, sample='rand', fix_start=None, input_fps=1, max_num_frames=-1):
    """
    Choose which frame indices of a ``vlen``-frame video should be read.

    Sampling modes (``sample``):
      * ``'rand'``   - split ``[0, vlen)`` into ``min(num_frames, vlen)``
        contiguous intervals and draw one random frame from each.
      * ``'middle'`` - same intervals, but take the middle frame of each; if
        ``fix_start`` is given, take ``interval_start + fix_start`` instead.
      * ``'fps<rate>'`` (e.g. ``'fps0.5'``) - walk through the video
        sequentially at ``<rate>`` frames per second, using ``input_fps`` to
        convert seconds into frame indices.

    Args:
        num_frames (int): number of indices to return for ``'rand'``/``'middle'``.
            When the video is shorter, the result is padded with its last index.
        vlen (int): number of frames available.
        sample (str): sampling mode, see above.
        fix_start (int, optional): offset from each interval start; only
            consulted when ``sample == 'middle'``.
        input_fps (float): frame rate of the source; only used by ``'fps*'``.
        max_num_frames (int): upper bound on the number of indices returned by
            ``'fps*'`` sampling; values ``<= 0`` disable the cap.

    Returns:
        list: the selected frame indices.

    Raises:
        ValueError: if ``sample`` is not a supported mode, or if frames are
            requested from a video that has none.
    """
    if sample in ["rand", "middle"]:  # uniform sampling
        acc_samples = min(num_frames, vlen)
        # Split the video into `acc_samples` intervals and sample from each one.
        intervals = np.linspace(start=0, stop=vlen, num=acc_samples + 1).astype(int)
        # (first, last) frame of every interval, both bounds inclusive.
        ranges = [
            (int(lo), int(hi) - 1)
            for lo, hi in zip(intervals[:-1], intervals[1:])
        ]

        if sample == 'rand':
            # `randint` is inclusive on both ends, so the last frame of every
            # interval can be drawn and single-frame intervals are valid.
            # `max` is a purely defensive clamp against a degenerate interval.
            frame_indices = [random.randint(lo, max(lo, hi)) for lo, hi in ranges]
        elif fix_start is not None:
            frame_indices = [lo + fix_start for lo, _ in ranges]
        else:  # sample == 'middle'
            frame_indices = [(lo + hi) // 2 for lo, hi in ranges]

        if len(frame_indices) < num_frames:  # pad with the last frame
            if not frame_indices:
                raise ValueError(
                    f"Cannot sample {num_frames} frame(s) from a video with {vlen} frames."
                )
            missing = num_frames - len(frame_indices)
            frame_indices = frame_indices + [frame_indices[-1]] * missing
    elif "fps" in sample:  # fps0.5, sequentially sample frames at 0.5 fps
        output_fps = float(sample[3:])
        duration = float(vlen) / input_fps
        delta = 1 / output_fps  # gap between frames; also the clip length each frame represents
        # Sample at the centre of every `delta`-second window.
        frame_seconds = np.arange(0 + delta / 2, duration + delta / 2, delta)
        candidates = np.around(frame_seconds * input_fps).astype(int)
        # Rounding can land on `vlen`; drop anything past the last frame.
        frame_indices = [int(e) for e in candidates if e < vlen]
        if max_num_frames > 0 and len(frame_indices) > max_num_frames:
            frame_indices = frame_indices[:max_num_frames]
    else:
        raise ValueError(f"Unsupported sampling strategy: {sample!r}")
    return frame_indices
def read_frames_av(
    video_path, num_frames, sample='rand', fix_start=None,
    max_num_frames=-1, client=None, clip=None,
):
    """
    Read sampled frames from a video using PyAV.

    The whole video stream is decoded to learn the frame count, the frame rate
    is estimated as ``frame_count / duration``, and only the frames chosen by
    ``get_frame_indices`` are converted to RGB tensors.

    Args:
        video_path: path (or file-like object) accepted by ``av.open``.
        num_frames, sample, fix_start, max_num_frames: forwarded to
            ``get_frame_indices``.
        client, clip: unused by this reader; accepted so that all readers share
            one signature.

    Returns:
        tuple: ``(frames, frame_indices, fps)`` where ``frames`` is a
        ``(T, C, H, W)`` ``torch.uint8`` tensor, ``frame_indices`` is the list
        of selected indices and ``fps`` is the estimated frame rate.
    """
    # The context manager closes the container even if decoding fails.
    with av.open(video_path) as reader:
        decoded = list(reader.decode(video=0))
        duration = get_pyav_video_duration(reader)

    vlen = len(decoded)
    fps = vlen / float(duration)
    frame_indices = get_frame_indices(
        num_frames, vlen, sample=sample, fix_start=fix_start,
        input_fps=fps, max_num_frames=max_num_frames
    )

    # Convert only the selected frames to RGB; padded selections repeat an
    # index, so each distinct frame is converted once and reused.
    rgb_by_index = {}
    for idx in frame_indices:
        if idx not in rgb_by_index:
            rgb_by_index[idx] = torch.from_numpy(decoded[idx].to_rgb().to_ndarray())

    frames = torch.stack([rgb_by_index[idx] for idx in frame_indices])  # (T, H, W, C), torch.uint8
    frames = frames.permute(0, 3, 1, 2)  # (T, C, H, W), torch.uint8
    return frames, frame_indices, fps
def read_frames_gif(
    video_path, num_frames, sample='rand', fix_start=None,
    max_num_frames=-1, client=None, clip=None,
):
    """
    Read sampled frames from a GIF using imageio.

    Args:
        video_path (str): local path, or a path starting with ``'s3'``/``'p2'``
            whose bytes are fetched through ``client.get``.
        num_frames, sample, fix_start, max_num_frames: forwarded to
            ``get_frame_indices``.
        client: object exposing ``get(path) -> bytes``; required for
            ``'s3'``/``'p2'`` paths.
        clip: unused by this reader; accepted so that all readers share one
            signature.

    Returns:
        tuple: ``(frames, frame_indices, fps)``. ``frames`` is a
        ``(T, C, H, W)`` ``torch.uint8`` tensor with exactly one entry per
        element of ``frame_indices`` (repeated indices yield repeated frames);
        ``fps`` is the constant ``25.`` (value chosen for TGIF).

    NOTE(review): ``get_frame_indices`` is called without ``input_fps``, so
    ``'fps*'`` sampling treats the GIF as 1 fps even though 25 fps is
    reported to the caller - confirm this is intended.
    """
    if video_path.startswith(('s3', 'p2')):
        video_bytes = client.get(video_path)
        gif = imageio.get_reader(io.BytesIO(video_bytes))
    else:
        gif = imageio.get_reader(video_path)

    try:
        vlen = len(gif)
        frame_indices = get_frame_indices(
            num_frames, vlen, sample=sample, fix_start=fix_start,
            max_num_frames=max_num_frames
        )
        wanted = {int(i) for i in frame_indices}  # O(1) membership test per frame
        decoded = {}
        for index, frame in enumerate(gif):
            if index not in wanted:
                continue
            frame = np.asarray(frame)
            # Normalise to 3-channel RGB whatever layout the reader produced.
            if frame.ndim == 2:
                frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
            elif frame.shape[2] == 4:
                frame = cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB)
            frame = torch.from_numpy(frame).byte()
            # (H x W x C) to (C x H x W)
            decoded[index] = frame.permute(2, 0, 1)
    finally:
        # Release the underlying file handle / buffer.
        gif.close()

    # Rebuild in `frame_indices` order so padded (repeated) indices are kept
    # and the output length always matches `len(frame_indices)`.
    frames = torch.stack([decoded[int(i)] for i in frame_indices])  # .float() / 255
    return frames, frame_indices, 25.  # for tgif
def read_frames_hdfs(ind_file, vid, num_frames, sample='rand',fix_start=None,
                     max_num_frames=-1, client=None, clip=None):
    """
    Read sampled JPEG frames of one video from a key-value store.

    The record stored under key ``vid`` is parsed as a TensorFlow
    ``SequenceExample`` whose ``'data'`` sequence feature holds one encoded
    JPEG per frame.

    NOTE: this function uses the names ``tf`` and ``KVReader``, whose imports
    are commented out at the top of this module; they must be made available
    before this reader can run.

    Args:
        ind_file (str): index file path; its extension is stripped before it is
            handed to ``KVReader``.
        vid: key of the video inside the store.
        num_frames, fix_start, max_num_frames: forwarded to
            ``get_frame_indices``.
        sample: accepted for signature compatibility, but overridden with
            ``"rand"`` below.
        client, clip: unused by this reader.

    Returns:
        tuple: ``(frames, frame_indices, fps)``. ``frames`` is a
        ``(T, C, H, W)`` tensor with exactly one entry per element of
        ``frame_indices``; ``fps`` is the constant ``25`` because the real
        frame rate is not known here.
    """
    _context_features = {'title': tf.io.FixedLenFeature([], dtype=tf.string)}
    _sequence_features = {'data': tf.io.FixedLenSequenceFeature([], dtype=tf.string)}
    num_parallel_reader = 1
    filename, _ = os.path.splitext(ind_file)
    reader = KVReader(filename, num_parallel_reader)
    item = reader.read_many([vid])[0]
    contexts, sequences = tf.io.parse_single_sequence_example(
        serialized=item,
        context_features=_context_features,
        sequence_features=_sequence_features)
    # text = contexts['title'].numpy().decode("utf-8")
    rawframes = sequences['data']
    vlen = len(rawframes)

    # NOTE(review): the caller's `sample` is deliberately kept overridden here
    # to preserve existing behaviour - confirm whether this is still wanted.
    sample = "rand"
    frame_indices = get_frame_indices(num_frames, vlen, sample=sample,
                                      fix_start=fix_start,
                                      max_num_frames=max_num_frames)

    wanted = {int(i) for i in frame_indices}  # O(1) membership test per frame
    decoded = {}
    for index, raw_data in enumerate(rawframes):
        if index in wanted:
            image = tf.image.decode_jpeg(
                raw_data, channels=3, dct_method='INTEGER_ACCURATE').numpy()
            decoded[index] = torch.as_tensor(image)

    # Rebuild in `frame_indices` order so padded (repeated) indices are kept
    # and the output length always matches `len(frame_indices)`.
    frames = torch.stack([decoded[int(i)] for i in frame_indices])
    frames = frames.permute(0, 3, 1, 2)  # (T, H, W, C) -> (T, C, H, W)
    return frames, frame_indices, 25  # don't know the fps for index
def read_frames_decord(
    video_path, num_frames, sample='rand', fix_start=None,
    max_num_frames=-1, client=None, clip=None
):
    """
    Read sampled frames from a video using decord.

    Args:
        video_path (str): local path, or a path starting with ``'s3'``/``'p2'``
            whose bytes are fetched through ``client.get``.
        num_frames, sample, fix_start, max_num_frames: forwarded to
            ``get_frame_indices``.
        client: object exposing ``get(path) -> bytes``; required for
            ``'s3'``/``'p2'`` paths.
        clip (tuple, optional): ``(start, end)`` in seconds. When given, frames
            are sampled only from that window; the window is clamped to the
            end of the video.

    Returns:
        tuple: ``(frames, frame_indices, fps)`` where ``frames`` is a
        ``(T, C, H, W)`` ``torch.uint8`` tensor, ``frame_indices`` are indices
        into the full video and ``fps`` is the average frame rate as a float.

    Raises:
        ValueError: if ``clip`` does not overlap the video.
    """
    if video_path.startswith(('s3', 'p2')):
        video_bytes = client.get(video_path)
        video_reader = VideoReader(io.BytesIO(video_bytes), num_threads=1)
    else:
        video_reader = VideoReader(video_path, num_threads=1)

    total_frames = len(video_reader)
    fps = video_reader.get_avg_fps()

    vlen = total_frames
    start_index = 0
    if clip:
        start, end = clip
        duration = end - start
        start_index = int(start * fps)
        # Never let the window run past the last decodable frame; otherwise
        # the shifted indices below would be out of range for `get_batch`.
        vlen = min(int(duration * fps), total_frames - start_index)
        if vlen <= 0:
            raise ValueError(
                f"clip {clip!r} does not overlap the video "
                f"({total_frames} frames at {float(fps):.3f} fps)."
            )

    frame_indices = get_frame_indices(
        num_frames, vlen, sample=sample, fix_start=fix_start,
        input_fps=fps, max_num_frames=max_num_frames
    )
    if clip:
        # Indices were sampled relative to the clip start; shift to absolute.
        frame_indices = [f + start_index for f in frame_indices]

    frames = video_reader.get_batch(frame_indices)  # (T, H, W, C), torch.uint8
    frames = frames.permute(0, 3, 1, 2)  # (T, C, H, W), torch.uint8
    return frames, frame_indices, float(fps)
# Registry mapping a backend name to the function that reads frames with it.
VIDEO_READER_FUNCS = dict(
    av=read_frames_av,
    decord=read_frames_decord,
    gif=read_frames_gif,
    hdfs=read_frames_hdfs,
)