fffiloni's picture
Upload 244 files
b3f324b verified
raw
history blame
No virus
3.56 kB
import math
import os
from glob import glob
import decord
import numpy as np
import torch
import torchvision
from decord import VideoReader, cpu
from torch.utils.data import Dataset
from torchvision.transforms import Compose, Lambda, ToTensor
from torchvision.transforms._transforms_video import NormalizeVideo, RandomCropVideo, RandomHorizontalFlipVideo
from pytorchvideo.transforms import ApplyTransformToKey, ShortSideScale, UniformTemporalSubsample
from torch.nn import functional as F
import random
from opensora.utils.dataset_utils import DecordInit
class Landscope(Dataset):
    """Video dataset over every ``*.mp4`` file found recursively under ``args.data_path``.

    Each item is a ``(video, 1)`` tuple. ``video`` is the transformed clip with
    the time and channel axes swapped (``C T H W``); the second element is the
    constant ``1`` (NOTE(review): presumably a dummy class label -- confirm
    against the training loop).

    Args:
        args: Namespace-like object providing ``data_path``, ``num_frames``,
            ``use_image_num`` and ``use_img_from_vid``.
        transform: Callable applied to the sampled ``T C H W`` frame tensor.
        temporal_sample: Callable mapping the total frame count of a video to
            a ``(start_frame_ind, end_frame_ind)`` pair.
    """

    # Number of replacement samples ``__getitem__`` may draw after the
    # requested one fails, before giving up with a ``RuntimeError``.
    max_retries = 100

    def __init__(self, args, transform, temporal_sample):
        self.data_path = args.data_path
        self.num_frames = args.num_frames
        self.transform = transform
        self.temporal_sample = temporal_sample
        self.v_decoder = DecordInit()
        self.samples = self._make_dataset()
        self.use_image_num = args.use_image_num
        self.use_img_from_vid = args.use_img_from_vid
        if self.use_image_num != 0 and not self.use_img_from_vid:
            # Stand-alone image/caption data is not implemented yet, so this
            # currently raises NotImplementedError (see get_img_cap_list).
            self.img_cap_list = self.get_img_cap_list()

    def _make_dataset(self):
        """Return the sorted list of all ``*.mp4`` paths below ``self.data_path``."""
        pattern = os.path.join(self.data_path, '**', '*.mp4')
        # glob yields paths in arbitrary filesystem order; sorting keeps the
        # index -> file mapping identical across runs and worker processes.
        return sorted(glob(pattern, recursive=True))

    def __len__(self):
        """Return the number of discovered video files."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load the clip at ``idx``, substituting a random sample on failure.

        Returns:
            ``(video, 1)`` where ``video`` is a ``C T H W`` tensor; when
            ``use_image_num != 0`` and ``use_img_from_vid`` is set, the
            selected frames are appended along the time axis.

        Raises:
            NotImplementedError: images are requested from a separate image
                source (``use_image_num != 0`` and not ``use_img_from_vid``).
            ValueError: more images are requested than frames are sampled.
            IndexError: ``idx`` is out of range.
            RuntimeError: the requested sample and ``max_retries`` random
                replacements all failed to load.
        """
        # Configuration problems fail for every sample, so they are reported
        # immediately instead of being swallowed by the retry loop below.
        if self.use_image_num != 0:
            if not self.use_img_from_vid:
                raise NotImplementedError
            if self.num_frames < self.use_image_num:
                raise ValueError(
                    f'use_image_num ({self.use_image_num}) must not exceed '
                    f'num_frames ({self.num_frames})'
                )

        video_path = self.samples[idx]
        last_error = None
        for _ in range(self.max_retries + 1):
            try:
                return self._load_sample(video_path), 1
            except Exception as e:
                # Best effort: a single unreadable file must not abort
                # training, so log it and fall back to a random sample.
                print(f'Error with {e}, {video_path}')
                last_error = e
                video_path = self.samples[random.randint(0, self.__len__() - 1)]
        raise RuntimeError(
            f'Failed to load a valid sample after {self.max_retries + 1} attempts'
        ) from last_error

    def _load_sample(self, video_path):
        """Read, transform and (optionally) image-augment one video file."""
        video = self.tv_read(video_path)
        video = self.transform(video)  # T C H W -> T C H W
        video = video.transpose(0, 1)  # T C H W -> C T H W
        if self.use_image_num != 0 and self.use_img_from_vid:
            # Evenly spaced frames of the clip double as still images.
            select_image_idx = np.linspace(0, self.num_frames - 1, self.use_image_num, dtype=int)
            images = video[:, select_image_idx]  # c, num_img, h, w
            video = torch.cat([video, images], dim=1)  # c, num_frame+num_img, h, w
        return video

    def _sample_frame_indices(self, total_frames):
        """Return ``num_frames`` evenly spaced indices inside the sampled window.

        The window comes from ``self.temporal_sample(total_frames)``. If it is
        shorter than ``num_frames``, the integer truncation makes some indices
        repeat.
        """
        start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
        return np.linspace(start_frame_ind, end_frame_ind - 1, self.num_frames, dtype=int)

    def tv_read(self, path):
        """Decode ``path`` with torchvision and return the sampled frames (T, C, H, W)."""
        # read_video also returns audio frames and metadata; only video is used.
        vframes, _, _ = torchvision.io.read_video(filename=path, pts_unit='sec', output_format='TCHW')
        frame_indice = self._sample_frame_indices(len(vframes))
        return vframes[frame_indice]  # (T, C, H, W)

    def decord_read(self, path):
        """Decode ``path`` through ``self.v_decoder`` and return the sampled frames (T, C, H, W)."""
        decord_vr = self.v_decoder(path)
        frame_indice = self._sample_frame_indices(len(decord_vr))
        video_data = torch.from_numpy(decord_vr.get_batch(frame_indice).asnumpy())
        return video_data.permute(0, 3, 1, 2)  # (T, H, W, C) -> (T C H W)

    def get_img_cap_list(self):
        """Placeholder for loading image/caption pairs; not implemented."""
        raise NotImplementedError