import os
import random

import pandas as pd
import torch
from decord import VideoReader, cpu
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchvision import transforms
from tqdm import tqdm


class WebVid(Dataset):
    """
    WebVid Dataset.
    Assumes webvid data is structured as follows.
    Webvid/
        videos/
            000001_000050/      ($page_dir)
                1.mp4           (videoid.mp4)
                ...
                5000.mp4
            ...

    The metadata CSV at ``meta_path`` is read with pandas and must provide the
    columns ``name`` (used as the caption), ``page_dir`` and ``videoid``.

    Each item is a dict with keys ``video``, ``caption``, ``path``, ``fps``
    and ``frame_stride``.
    """

    def __init__(self,
                 meta_path,
                 data_dir,
                 subsample=None,
                 video_length=16,
                 resolution=[256, 512],
                 frame_stride=1,
                 frame_stride_min=1,
                 spatial_transform=None,
                 crop_resolution=None,
                 fps_max=None,
                 load_raw_resolution=False,
                 fixed_fps=None,
                 random_fs=False,
                 ):
        """
        Args:
            meta_path: path to the metadata CSV file.
            data_dir: dataset root; videos are looked up under ``data_dir/videos``.
            subsample: if not None, number of metadata rows to sample (seed 0).
            video_length: number of frames per returned clip.
            resolution: expected (height, width) of the output frames; an int
                is expanded to a square ``[resolution, resolution]``.
            frame_stride: stride between sampled frames (upper bound of the
                random stride when ``random_fs`` is set).
            frame_stride_min: lower bound of the random stride.
            spatial_transform: one of ``None``, ``"random_crop"``,
                ``"center_crop"``, ``"resize_center_crop"``, ``"resize"``.
            crop_resolution: crop size used by ``"random_crop"``.
            fps_max: if not None, upper bound for the reported clip fps.
            load_raw_resolution: decode at the native video size instead of
                the fixed 530x300 decode size.
            fixed_fps: if not None, the stride is rescaled by
                ``fps_ori / fixed_fps`` for every video.
            random_fs: draw the stride randomly for each item.

        Raises:
            NotImplementedError: if ``spatial_transform`` is not supported.
        """
        self.meta_path = meta_path
        self.data_dir = data_dir
        self.subsample = subsample
        self.video_length = video_length
        if isinstance(resolution, int):
            resolution = [resolution, resolution]
        elif isinstance(resolution, list):
            ## private copy so that the shared default list is never aliased
            resolution = list(resolution)
        self.resolution = resolution
        self.fps_max = fps_max
        self.frame_stride = frame_stride
        self.frame_stride_min = frame_stride_min
        self.fixed_fps = fixed_fps
        self.load_raw_resolution = load_raw_resolution
        self.random_fs = random_fs
        self._load_metadata()
        self.spatial_transform = self._build_spatial_transform(spatial_transform, crop_resolution)

    def _build_spatial_transform(self, spatial_transform, crop_resolution):
        """Return the torchvision transform selected by name, or None."""
        if spatial_transform is None:
            return None
        if spatial_transform == "random_crop":
            return transforms.RandomCrop(crop_resolution)
        if spatial_transform == "center_crop":
            return transforms.Compose([
                transforms.CenterCrop(self.resolution),
            ])
        if spatial_transform == "resize_center_crop":
            ## resize the shorter side first, then crop to the target size
            return transforms.Compose([
                transforms.Resize(min(self.resolution)),
                transforms.CenterCrop(self.resolution),
            ])
        if spatial_transform == "resize":
            return transforms.Resize(self.resolution)
        raise NotImplementedError(f"unsupported spatial_transform: {spatial_transform!r}")

    def _load_metadata(self):
        """Read the metadata CSV, optionally subsample it, and store it in ``self.metadata``."""
        metadata = pd.read_csv(self.meta_path)
        print(f'>>> {len(metadata)} data samples loaded.')
        if self.subsample is not None:
            metadata = metadata.sample(self.subsample, random_state=0)

        ## the 'name' column holds the caption text
        metadata['caption'] = metadata.pop('name')
        self.metadata = metadata
        ## rows with any missing field are discarded
        self.metadata.dropna(inplace=True)

    def _get_video_path(self, sample):
        """Return ``data_dir/videos/<page_dir>/<videoid>.mp4`` for a metadata row."""
        rel_video_fp = os.path.join(sample['page_dir'], str(sample['videoid']) + '.mp4')
        full_video_fp = os.path.join(self.data_dir, 'videos', rel_video_fp)
        return full_video_fp

    def _open_video(self, video_path):
        """Open a decord reader, at native size or at the fixed 530x300 decode size."""
        if self.load_raw_resolution:
            return VideoReader(video_path, ctx=cpu(0))
        return VideoReader(video_path, ctx=cpu(0), width=530, height=300)

    def __getitem__(self, index):
        """
        Return a clip of ``video_length`` frames starting the search at ``index``.

        Videos that cannot be opened, are too short, or fail to decode are
        skipped and the next metadata row is tried instead.

        Returns:
            dict with ``video`` (float tensor laid out [c,t,h,w], scaled to
            [-1,1]), ``caption``, ``path``, ``fps`` and ``frame_stride``.

        Raises:
            RuntimeError: if no usable video is found after trying every
                metadata row once (including when the dataset is empty).
        """
        if self.random_fs:
            base_stride = random.randint(self.frame_stride_min, self.frame_stride)
        else:
            base_stride = self.frame_stride

        num_samples = len(self.metadata)
        ## get frames until success, trying each sample at most once
        for _ in range(num_samples):
            index = index % num_samples
            sample = self.metadata.iloc[index]
            video_path = self._get_video_path(sample)
            ## video_path should be in the format of "....../WebVid/videos/$page_dir/$videoid.mp4"
            caption = sample['caption']

            try:
                video_reader = self._open_video(video_path)
                frame_num = len(video_reader)
            except Exception:
                ## `Exception` (not a bare except) so Ctrl-C can still stop the loop
                print(f"Load video failed! path = {video_path}")
                index += 1
                continue

            if frame_num < self.video_length:
                print(f"video length ({frame_num}) is smaller than target length({self.video_length})")
                index += 1
                continue

            fps_ori = video_reader.get_avg_fps()
            ## always derive the stride from the base stride, so that the
            ## rescaling of a skipped video does not leak into the next one
            frame_stride = base_stride
            if self.fixed_fps is not None:
                frame_stride = int(base_stride * (1.0 * fps_ori / self.fixed_fps))

            ## to avoid extreme cases when fixed_fps is used
            frame_stride = max(frame_stride, 1)

            ## get valid range (adapting case by case)
            required_frame_num = frame_stride * (self.video_length - 1) + 1
            if frame_num < required_frame_num:
                ## drop extra samples if fixed fps is required
                if self.fixed_fps is not None and frame_num < required_frame_num * 0.5:
                    index += 1
                    continue
                ## otherwise shrink the stride so the clip fits into the video
                frame_stride = frame_num // self.video_length
                required_frame_num = frame_stride * (self.video_length - 1) + 1

            ## select a random clip
            random_range = frame_num - required_frame_num
            start_idx = random.randint(0, random_range) if random_range > 0 else 0

            ## calculate frame indices
            frame_indices = [start_idx + frame_stride * i for i in range(self.video_length)]
            try:
                frames = video_reader.get_batch(frame_indices)
            except Exception:
                print(f"Get frames failed! path = {video_path}; [max_ind vs frame_total:{max(frame_indices)} / {frame_num}]")
                index += 1
                continue
            break
        else:
            raise RuntimeError(
                f"No decodable video found after trying all {num_samples} samples."
            )

        ## process data
        assert (frames.shape[0] == self.video_length), f'{len(frames)}, self.video_length={self.video_length}'
        frames = torch.tensor(frames.asnumpy()).permute(3, 0, 1, 2).float()  # [t,h,w,c] -> [c,t,h,w]

        if self.spatial_transform is not None:
            frames = self.spatial_transform(frames)

        if self.resolution is not None:
            ## NOTE(review): with "random_crop" the output size is crop_resolution,
            ## so this check presumably requires crop_resolution == resolution -- confirm
            assert (frames.shape[2], frames.shape[3]) == (self.resolution[0], self.resolution[1]), f'frames={frames.shape}, self.resolution={self.resolution}'

        ## turn frames tensors to [-1,1]
        frames = (frames / 255 - 0.5) * 2
        fps_clip = fps_ori // frame_stride
        if self.fps_max is not None and fps_clip > self.fps_max:
            fps_clip = self.fps_max

        data = {'video': frames, 'caption': caption, 'path': video_path, 'fps': fps_clip, 'frame_stride': frame_stride}
        return data

    def __len__(self):
        """Return the number of metadata rows kept after loading."""
        return len(self.metadata)


if __name__ == "__main__":
    meta_path = ""  ## path to the meta file
    data_dir = ""  ## path to the data directory
    save_dir = ""  ## path to the save directory
    dataset = WebVid(meta_path,
                     data_dir,
                     subsample=None,
                     video_length=16,
                     resolution=[256, 448],
                     frame_stride=4,
                     spatial_transform="resize_center_crop",
                     crop_resolution=None,
                     fps_max=None,
                     load_raw_resolution=True
                     )
    dataloader = DataLoader(dataset,
                            batch_size=1,
                            num_workers=0,
                            shuffle=False)

    import sys
    ## make the project root importable so that `utils` can be found
    sys.path.insert(1, os.path.join(sys.path[0], '..', '..'))
    from utils.save_video import tensor_to_mp4
    for batch in tqdm(dataloader, desc="Data Batch"):
        video = batch['video']
        ## flatten "$page_dir/$videoid.mp4" into a single file name
        name = batch['path'][0].split('videos/')[-1].replace('/', '_')
        ## os.path.join so an empty save_dir does not resolve to the filesystem root
        tensor_to_mp4(video, os.path.join(save_dir, name), fps=8)