|
import math |
|
import os |
|
from glob import glob |
|
|
|
import decord |
|
import numpy as np |
|
import torch |
|
import torchvision |
|
from decord import VideoReader, cpu |
|
from torch.utils.data import Dataset |
|
from torchvision.transforms import Compose, Lambda, ToTensor |
|
from torchvision.transforms._transforms_video import NormalizeVideo, RandomCropVideo, RandomHorizontalFlipVideo |
|
from pytorchvideo.transforms import ApplyTransformToKey, ShortSideScale, UniformTemporalSubsample |
|
from torch.nn import functional as F |
|
import random |
|
|
|
from opensora.utils.dataset_utils import DecordInit |
|
|
|
|
|
class Landscope(Dataset):
    """Video dataset over every ``*.mp4`` file found under ``args.data_path``.

    Each item is a ``(video, 1)`` tuple: the decoded, temporally sub-sampled
    and transformed clip together with the constant ``1`` in the label slot.

    Args:
        args: Configuration object. The attributes read here are
            ``data_path``, ``num_frames``, ``use_image_num`` and
            ``use_img_from_vid``.
        transform: Callable applied to the sampled frames returned by
            :meth:`tv_read`.
        temporal_sample: Callable that takes the total number of frames of a
            video and returns a ``(start_frame_ind, end_frame_ind)`` pair.
    """

    # How many randomly chosen replacement samples ``__getitem__`` tries after
    # a failed load before giving up. Override on a subclass or instance.
    max_retries = 10

    def __init__(self, args, transform, temporal_sample):
        self.data_path = args.data_path
        self.num_frames = args.num_frames
        self.transform = transform
        self.temporal_sample = temporal_sample
        self.v_decoder = DecordInit()

        self.samples = self._make_dataset()
        self.use_image_num = args.use_image_num
        self.use_img_from_vid = args.use_img_from_vid
        if self.use_image_num != 0 and not self.use_img_from_vid:
            self.img_cap_list = self.get_img_cap_list()

    def _make_dataset(self):
        """Return the paths of all ``*.mp4`` files below ``self.data_path``.

        The search is recursive. ``glob`` yields results in arbitrary order,
        so the list is sorted to keep the index-to-sample mapping reproducible
        across runs and processes.
        """
        pattern = os.path.join(self.data_path, '**', '*.mp4')
        return sorted(glob(pattern, recursive=True))

    def __len__(self):
        """Return the number of video files that were discovered."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load, sub-sample and transform the video at position ``idx``.

        If reading or transforming a sample fails, the error is printed and a
        randomly chosen sample is tried instead, at most ``max_retries``
        times.

        Returns:
            A ``(video, 1)`` tuple. The first two axes of the transformed clip
            are swapped before it is returned. When ``use_image_num`` is
            non-zero and ``use_img_from_vid`` is set, ``use_image_num`` evenly
            spaced frames of the clip are appended along dimension 1.

        Raises:
            IndexError: If ``idx`` is out of range.
            NotImplementedError: If images are requested from an external
                image/caption list (``use_img_from_vid`` is false).
            ValueError: If ``use_image_num`` exceeds ``num_frames``.
            RuntimeError: If every attempted sample failed to load.
        """
        wants_images = self.use_image_num != 0

        # Configuration errors affect every sample alike. Raise them directly
        # instead of letting the retry logic below swallow them.
        if wants_images and not self.use_img_from_vid:
            raise NotImplementedError(
                'Loading images from an image/caption list is not implemented.'
            )
        if wants_images and self.num_frames < self.use_image_num:
            raise ValueError(
                f'use_image_num ({self.use_image_num}) must not exceed '
                f'num_frames ({self.num_frames}).'
            )

        last_error = None
        for _ in range(self.max_retries + 1):
            # Deliberately outside the try: an out-of-range index must surface.
            video_path = self.samples[idx]
            try:
                video = self.tv_read(video_path)
                video = self.transform(video)
                # tv_read returns frames as (T, C, H, W); presumably the
                # transform keeps that layout, so this gives (C, T, H, W)
                # -- TODO confirm against the transform in use.
                video = video.transpose(0, 1)
                if wants_images:
                    select_image_idx = np.linspace(
                        0, self.num_frames - 1, self.use_image_num, dtype=int
                    )
                    images = video[:, select_image_idx]
                    video = torch.cat([video, images], dim=1)
                return video, 1
            except Exception as e:
                # Best effort: a single unreadable file should not abort the
                # whole run, so report it and fall back to a random sample.
                print(f'Error with {e}, {video_path}')
                last_error = e
                idx = random.randint(0, len(self) - 1)

        raise RuntimeError(
            f'Failed to load a video after {self.max_retries + 1} attempts.'
        ) from last_error

    def _sample_frame_indices(self, total_frames):
        """Return ``self.num_frames`` evenly spaced integer frame indices.

        The indices span the window that ``self.temporal_sample`` selects for
        a video with ``total_frames`` frames; the window end is exclusive.
        """
        start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
        return np.linspace(
            start_frame_ind, end_frame_ind - 1, self.num_frames, dtype=int
        )

    def tv_read(self, path):
        """Decode ``path`` with torchvision and return the sampled frames.

        The frames are requested in ``TCHW`` layout, and the result keeps
        that layout with ``self.num_frames`` entries along the first axis.
        """
        vframes, _, _ = torchvision.io.read_video(
            filename=path, pts_unit='sec', output_format='TCHW'
        )
        frame_indice = self._sample_frame_indices(len(vframes))
        return vframes[frame_indice]

    def decord_read(self, path):
        """Decode ``path`` with decord and return the sampled frames.

        Only the sampled frames are fetched from the reader. The batch is
        permuted with ``(0, 3, 1, 2)``, i.e. the last axis is moved to
        position 1 (presumably channels-last to ``TCHW`` -- TODO confirm the
        reader's output layout).
        """
        decord_vr = self.v_decoder(path)
        frame_indice = self._sample_frame_indices(len(decord_vr))

        video_data = decord_vr.get_batch(frame_indice).asnumpy()
        video_data = torch.from_numpy(video_data)
        return video_data.permute(0, 3, 1, 2)

    def get_img_cap_list(self):
        """Build the external image/caption list (not implemented)."""
        raise NotImplementedError
|
|