Spaces:
Runtime error
Runtime error
File size: 8,128 Bytes
0366b8b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 |
import os
import random
from tqdm import tqdm
import pandas as pd
from decord import VideoReader, cpu
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms
class WebVid(Dataset):
    """
    WebVid Dataset.

    Assumes webvid data is structured as follows.
    Webvid/
        videos/
            000001_000050/      ($page_dir)
                1.mp4           (videoid.mp4)
                ...
                5000.mp4
            ...

    The metadata CSV at ``meta_path`` is read for the ``videoid``, ``page_dir``
    and ``name`` columns; ``name`` is exposed as the sample caption.
    Each item is a dict with keys ``video`` (float tensor laid out as
    [c, t, h, w], rescaled by ``(x / 255 - 0.5) * 2``), ``caption``, ``path``,
    ``fps`` and ``frame_stride``.
    """

    def __init__(self,
                 meta_path,
                 data_dir,
                 subsample=None,
                 video_length=16,
                 resolution=(256, 512),
                 frame_stride=1,
                 frame_stride_min=1,
                 spatial_transform=None,
                 crop_resolution=None,
                 fps_max=None,
                 load_raw_resolution=False,
                 fixed_fps=None,
                 random_fs=False,
                 ):
        """
        Args:
            meta_path: path to the metadata CSV.
            data_dir: dataset root that contains the ``videos/`` directory.
            subsample: if given, number of metadata rows sampled (random_state=0).
            video_length: number of frames per returned clip.
            resolution: expected output (h, w), or an int for a square output;
                None disables the output-size assertion.
            frame_stride: stride between sampled frames (upper bound when
                ``random_fs`` is set).
            frame_stride_min: lower bound of the stride when ``random_fs`` is set.
            spatial_transform: "random_crop", "center_crop",
                "resize_center_crop", "resize", or None.
            crop_resolution: crop size used by "random_crop".
            fps_max: upper bound applied to the reported clip fps.
            load_raw_resolution: decode at native size instead of 530x300.
            fixed_fps: if given, rescale the stride by ``fps / fixed_fps``.
            random_fs: draw the stride from [frame_stride_min, frame_stride]
                once per item.

        Raises:
            NotImplementedError: for an unknown ``spatial_transform`` name.
        """
        self.meta_path = meta_path
        self.data_dir = data_dir
        self.subsample = subsample
        self.video_length = video_length
        if isinstance(resolution, int):
            self.resolution = [resolution, resolution]
        elif resolution is None:
            self.resolution = None
        else:
            # Copy so the instance never aliases the caller's sequence.
            self.resolution = list(resolution)
        self.fps_max = fps_max
        self.frame_stride = frame_stride
        self.frame_stride_min = frame_stride_min
        self.fixed_fps = fixed_fps
        self.load_raw_resolution = load_raw_resolution
        self.random_fs = random_fs
        self._load_metadata()

        if spatial_transform is None:
            self.spatial_transform = None
        elif spatial_transform == "random_crop":
            self.spatial_transform = transforms.RandomCrop(crop_resolution)
        elif spatial_transform == "center_crop":
            self.spatial_transform = transforms.Compose([
                transforms.CenterCrop(self.resolution),
            ])
        elif spatial_transform == "resize_center_crop":
            self.spatial_transform = transforms.Compose([
                transforms.Resize(min(self.resolution)),
                transforms.CenterCrop(self.resolution),
            ])
        elif spatial_transform == "resize":
            self.spatial_transform = transforms.Resize(self.resolution)
        else:
            raise NotImplementedError(
                f"Unsupported spatial_transform: {spatial_transform!r}")

    def _load_metadata(self):
        """Read the metadata CSV, optionally subsample it, and rename ``name`` to ``caption``.

        Rows with any missing value are dropped after subsampling, so
        ``len(self)`` can be smaller than ``subsample``.
        """
        metadata = pd.read_csv(self.meta_path)
        print(f'>>> {len(metadata)} data samples loaded.')
        if self.subsample is not None:
            metadata = metadata.sample(self.subsample, random_state=0)
        metadata['caption'] = metadata['name']
        del metadata['name']
        self.metadata = metadata
        self.metadata.dropna(inplace=True)

    def _get_video_path(self, sample):
        """Return ``<data_dir>/videos/<page_dir>/<videoid>.mp4`` for a metadata row."""
        rel_video_fp = os.path.join(sample['page_dir'], str(sample['videoid']) + '.mp4')
        full_video_fp = os.path.join(self.data_dir, 'videos', rel_video_fp)
        return full_video_fp

    def _sample_frame_indices(self, frame_num, fps_ori, base_stride):
        """Pick the frame stride and the frame indices of one clip.

        Args:
            frame_num: number of frames in the video.
            fps_ori: the video's average fps.
            base_stride: requested stride before any ``fixed_fps`` rescaling.

        Returns:
            ``(frame_stride, frame_indices)``, or ``None`` when ``fixed_fps`` is
            set and the video has fewer than half the frames the clip needs.
        """
        frame_stride = base_stride
        if self.fixed_fps is not None:
            # Rescale so the sampled clip approximates fixed_fps.
            frame_stride = int(frame_stride * (1.0 * fps_ori / self.fixed_fps))
        # A stride of 0 would repeat one frame and divide by zero in fps_clip.
        frame_stride = max(frame_stride, 1)

        ## get valid range (adapting case by case)
        required_frame_num = frame_stride * (self.video_length - 1) + 1
        if frame_num < required_frame_num:
            ## drop extra samples if fixed fps is required
            if self.fixed_fps is not None and frame_num < required_frame_num * 0.5:
                return None
            # Otherwise shrink the stride so the clip fits inside the video.
            frame_stride = max(frame_num // self.video_length, 1)
            required_frame_num = frame_stride * (self.video_length - 1) + 1

        ## select a random clip
        random_range = frame_num - required_frame_num
        start_idx = random.randint(0, random_range) if random_range > 0 else 0
        frame_indices = [start_idx + frame_stride * i for i in range(self.video_length)]
        return frame_stride, frame_indices

    def __getitem__(self, index):
        """Return one clip, moving to the next row (wrapping around) whenever a video is unusable.

        Note: this loops forever if no video in the metadata can be used.
        """
        if self.random_fs:
            base_stride = random.randint(self.frame_stride_min, self.frame_stride)
        else:
            base_stride = self.frame_stride

        ## get frames until success
        while True:
            index = index % len(self.metadata)
            sample = self.metadata.iloc[index]
            video_path = self._get_video_path(sample)
            ## video_path should be in the format of "....../WebVid/videos/$page_dir/$videoid.mp4"
            caption = sample['caption']
            try:
                if self.load_raw_resolution:
                    video_reader = VideoReader(video_path, ctx=cpu(0))
                else:
                    video_reader = VideoReader(video_path, ctx=cpu(0), width=530, height=300)
                if len(video_reader) < self.video_length:
                    print(f"video length ({len(video_reader)}) is smaller than target length({self.video_length})")
                    index += 1
                    continue
            except Exception:
                index += 1
                print(f"Load video failed! path = {video_path}")
                continue

            fps_ori = video_reader.get_avg_fps()
            frame_num = len(video_reader)
            # Start from base_stride for every candidate so an adaptation made
            # for a rejected video never carries over to the next one.
            sampled = self._sample_frame_indices(frame_num, fps_ori, base_stride)
            if sampled is None:
                index += 1
                continue
            frame_stride, frame_indices = sampled
            try:
                frames = video_reader.get_batch(frame_indices)
                break
            except Exception:
                print(f"Get frames failed! path = {video_path}; [max_ind vs frame_total:{max(frame_indices)} / {frame_num}]")
                index += 1
                continue

        ## process data
        assert (frames.shape[0] == self.video_length), f'{len(frames)}, self.video_length={self.video_length}'
        frames = torch.tensor(frames.asnumpy()).permute(3, 0, 1, 2).float()  # [t,h,w,c] -> [c,t,h,w]
        if self.spatial_transform is not None:
            frames = self.spatial_transform(frames)
        if self.resolution is not None:
            assert (frames.shape[2], frames.shape[3]) == (self.resolution[0], self.resolution[1]), f'frames={frames.shape}, self.resolution={self.resolution}'
        ## turn frames tensors to [-1,1]
        frames = (frames / 255 - 0.5) * 2
        fps_clip = fps_ori // frame_stride
        if self.fps_max is not None and fps_clip > self.fps_max:
            fps_clip = self.fps_max
        data = {'video': frames, 'caption': caption, 'path': video_path, 'fps': fps_clip, 'frame_stride': frame_stride}
        return data

    def __len__(self):
        """Number of metadata rows remaining after subsampling and ``dropna``."""
        return len(self.metadata)
if __name__ == "__main__":
    # Debug entry point: decode every sample and dump it as an mp4 for inspection.
    import sys

    meta_path = ""  ## path to the meta file
    data_dir = ""  ## path to the data directory
    save_dir = ""  ## path to the save directory
    dataset = WebVid(meta_path,
                     data_dir,
                     subsample=None,
                     video_length=16,
                     resolution=[256, 448],
                     frame_stride=4,
                     spatial_transform="resize_center_crop",
                     crop_resolution=None,
                     fps_max=None,
                     load_raw_resolution=True
                     )
    dataloader = DataLoader(dataset,
                            batch_size=1,
                            num_workers=0,
                            shuffle=False)

    # Put the directory two levels above the script's directory on sys.path so
    # that `utils.save_video` can be imported.
    sys.path.insert(1, os.path.join(sys.path[0], '..', '..'))
    from utils.save_video import tensor_to_mp4

    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    for batch in tqdm(dataloader, desc="Data Batch"):
        video = batch['video']
        # Flatten "<page_dir>/<videoid>.mp4" into a single file name.
        name = batch['path'][0].split('videos/')[-1].replace('/', '_')
        # os.path.join avoids writing to "/<name>" when save_dir is empty.
        tensor_to_mp4(video, os.path.join(save_dir, name), fps=8)
|