# UniVTG / run_on_video/video_loader.py
# Author: KevinQHLin
# Commit 9d0a4ae: "Upload 60 files"
import torch as th
from torch.utils.data import Dataset
import pandas as pd
import os
import numpy as np
import ffmpeg
import math
def convert_to_float(frac_str):
    """Parse a number written as a decimal, a fraction or a mixed number.

    Examples: ``"29.97"`` -> 29.97, ``"30000/1001"`` -> 29.97002997...,
    ``"1 1/2"`` -> 1.5, ``"-1 1/2"`` -> -1.5.

    Args:
        frac_str: The string to parse. A plain ``int``/``float`` is also
            accepted and simply converted.

    Returns:
        The parsed value as ``float``, or ``None`` when the text is not a
        recognisable number/fraction or the fraction's denominator is zero
        (e.g. ``"0/0"``).

    Raises:
        TypeError: for non-string, non-numeric input such as ``None``.
    """
    try:
        return float(frac_str)
    except ValueError:
        pass  # not a plain number; try fraction / mixed-number syntax below

    try:
        whole_part, denom = frac_str.split('/')
    except ValueError:  # zero or several '/' characters
        return None

    # Whitespace separates an optional integer part from the numerator: "1 1/2".
    tokens = whole_part.split()
    try:
        if len(tokens) == 1:
            return float(tokens[0]) / float(denom)
        if len(tokens) == 2:
            leading, num = tokens
            # Take the sign from the text so "-0 1/2" is -0.5, not +0.5.
            sign_mult = -1 if leading.startswith('-') else 1
            return float(leading) + sign_mult * (float(num) / float(denom))
    except (ValueError, ZeroDivisionError):
        return None
    return None
class VideoLoader(Dataset):
    """Pytorch video loader for a single video file.

    The dataset always has length 1. Indexing it decodes ``vid_path`` with
    ffmpeg at (about) ``framerate`` frames per second and rescales the frames.
    It returns them as a float32 tensor of shape ``(T, 3, H, W)`` holding
    RGB values in [0, 255].
    """

    def __init__(
        self,
        vid_path,
        framerate=1,
        size=112,
        centercrop=False,
        overwrite=False,
        model_version="ViT-B/32",
    ):
        """
        Args:
            vid_path: Path of the video file to decode.
            framerate: Frames per second to sample, passed to ffmpeg's ``fps``
                filter. It is overridden for very short clips; see
                ``_sampling_fps``.
            size: Either an int, the target length of the shorter frame side
                (aspect ratio preserved), or an ``(height, width)`` tuple giving
                the exact output size.
            centercrop: If True and ``size`` is an int, crop the rescaled
                frames to a centred ``size x size`` square.
            overwrite: Stored on the instance; not read by this class.
            model_version: Stored on the instance; not read by this class.
        """
        self.vid_path = vid_path
        self.centercrop = centercrop
        self.size = size
        self.framerate = framerate
        self.overwrite = overwrite
        self.model_version = model_version

    def __len__(self):
        """The loader wraps exactly one video."""
        return 1

    @staticmethod
    def _read_duration(probe, video_stream):
        """Return the duration in seconds, or -1 if none is reported.

        Prefers the video stream's own ``duration``. It falls back to the
        container-level ``probe['format']['duration']`` in case the stream
        entry lacks it.
        """
        for source in (video_stream, probe.get('format', {})):
            try:
                return float(source['duration'])
            except (KeyError, TypeError, ValueError):
                continue
        return -1

    @staticmethod
    def _read_frames_length(video_stream):
        """Return the stream's ``nb_frames`` as int, or -1 if missing/invalid."""
        try:
            return int(video_stream['nb_frames'])
        except (KeyError, TypeError, ValueError):
            return -1

    @staticmethod
    def _read_fps(video_stream):
        """Return floor(average frame rate), or -1 if missing/unparseable."""
        try:
            rate = convert_to_float(video_stream['avg_frame_rate'])
        except (KeyError, TypeError, ValueError, ZeroDivisionError):
            return -1
        # A None, NaN or infinite rate cannot be floored into a frame count.
        if rate is None or not math.isfinite(rate):
            return -1
        return math.floor(rate)

    def _get_video_info(self, video_path):
        """Probe ``video_path`` and summarise its first video stream.

        Returns:
            dict with keys ``duration`` (seconds, -1 if unknown),
            ``frames_length`` (frame count, -1 if unknown), ``fps`` (floor of
            the average frame rate, -1 if unknown), ``height`` and ``width``.

        Raises:
            ValueError: if the probe lists no stream with codec_type 'video'.
            Errors raised by ``ffmpeg.probe`` itself are propagated unchanged.
        """
        probe = ffmpeg.probe(video_path)
        video_stream = next(
            (stream for stream in probe['streams']
             if stream['codec_type'] == 'video'),
            None,
        )
        if video_stream is None:
            raise ValueError('no video stream found in {}'.format(video_path))
        return {
            "duration": self._read_duration(probe, video_stream),
            "frames_length": self._read_frames_length(video_stream),
            "fps": self._read_fps(video_stream),
            "height": int(video_stream['height']),
            "width": int(video_stream['width']),
        }

    def _get_output_dim(self, h, w):
        """Return the ``(height, width)`` frames are scaled to before cropping.

        A 2-tuple ``self.size`` is returned as-is. For an int ``size``, the
        shorter side becomes ``size`` and the other side keeps the aspect
        ratio, truncated to an int.
        """
        if isinstance(self.size, tuple) and len(self.size) == 2:
            return self.size
        elif h >= w:
            return int(h * self.size / w), self.size
        else:
            return self.size, int(w * self.size / h)

    def _sampling_fps(self, duration):
        """Return the rate for ffmpeg's ``fps`` filter for a clip of ``duration`` s.

        Clips with ``0 < duration < 1 / framerate + 0.1`` are sampled at
        ``2 / max(int(duration), 1)`` fps instead of ``self.framerate``. This is
        presumably so very short clips still yield frames. If the comparison
        cannot be evaluated (zero or non-numeric framerate),
        ``self.framerate`` is used unchanged.
        """
        fps = self.framerate
        try:
            if 0 < duration < 1 / fps + 0.1:
                fps = 2 / max(int(duration), 1)
                print(duration, fps)
        except (TypeError, ZeroDivisionError):
            fps = self.framerate
        return fps

    def __getitem__(self, id):
        """Decode the video; ``id`` is ignored since there is only one item.

        Returns:
            ``{'video': tensor, 'input': vid_path}``, where ``video`` is a
            float32 tensor of shape (T, 3, H, W). If the file does not exist,
            ``video`` is ``th.zeros(1)``. If probing or decoding fails,
            ``video`` is ``th.zeros(1)`` and an empty ``'info'`` dict is
            included.
        """
        video_path = self.vid_path
        if not os.path.isfile(video_path):
            return {'video': th.zeros(1), 'input': video_path}

        try:
            info = self._get_video_info(video_path)
            h, w = info["height"], info["width"]
        except Exception:
            print('ffprobe failed at: {}'.format(video_path))
            return {'video': th.zeros(1), 'input': video_path, 'info': {}}

        try:
            height, width = self._get_output_dim(h, w)
            fps = self._sampling_fps(info["duration"])
            cmd = (
                ffmpeg
                .input(video_path)
                .filter('fps', fps=fps)
                .filter('scale', width, height)
            )
            # With a tuple size the scale step already yields the exact target
            # shape, so the square centre crop only applies to an int size.
            if self.centercrop and isinstance(self.size, int):
                x = int((width - self.size) / 2.0)
                y = int((height - self.size) / 2.0)
                cmd = cmd.crop(x, y, self.size, self.size)
                height, width = self.size, self.size
            out, _ = (
                cmd.output('pipe:', format='rawvideo', pix_fmt='rgb24')
                .run(capture_stdout=True, quiet=True)
            )
            # Raw rgb24 bytes -> (T, H, W, 3) uint8 -> float32 (T, 3, H, W).
            frames = np.frombuffer(out, np.uint8).reshape([-1, height, width, 3])
            video = th.from_numpy(frames.astype('float32')).permute(0, 3, 1, 2)
        except Exception:
            print('ffmpeg decoding failed at: {}'.format(video_path))
            return {'video': th.zeros(1), 'input': video_path, 'info': {}}
        return {'video': video, 'input': video_path}