import json
import random
from typing import List

import numpy as np
import pandas as pd
import torch
import torchvision.transforms as transforms
from decord import VideoReader
from PIL import Image
from torch.utils.data import Dataset
from transformers import CLIPImageProcessor


class HumanDanceVideoDataset(Dataset):
    def __init__(
        self,
        sample_rate,
        n_sample_frames,
        width,
        height,
        img_scale=(1.0, 1.0),
        img_ratio=(0.9, 1.0),
        drop_ratio=0.1,
        data_meta_paths=["./data/fashion_meta.json"],
    ):
        super().__init__()

        self.sample_rate = sample_rate
        self.n_sample_frames = n_sample_frames
        self.width = width
        self.height = height
        self.img_scale = img_scale
        self.img_ratio = img_ratio

        # Merge the video metadata from every meta file into a single list.
        vid_meta = []
        for data_meta_path in data_meta_paths:
            vid_meta.extend(json.load(open(data_meta_path, "r")))
        self.vid_meta = vid_meta

        self.clip_image_processor = CLIPImageProcessor()

        # Transform for the target / reference RGB frames: resize, then
        # normalize to [-1, 1].
        self.pixel_transform = transforms.Compose(
            [
                # transforms.RandomResizedCrop(
                #     (height, width),
                #     scale=self.img_scale,
                #     ratio=self.img_ratio,
                #     interpolation=transforms.InterpolationMode.BILINEAR,
                # ),
                transforms.Resize((height, width)),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

        # Transform for the pose-condition frames: resize only, values stay in [0, 1].
        self.cond_transform = transforms.Compose(
            [
                # transforms.RandomResizedCrop(
                #     (height, width),
                #     scale=self.img_scale,
                #     ratio=self.img_ratio,
                #     interpolation=transforms.InterpolationMode.BILINEAR,
                # ),
                transforms.Resize((height, width)),
                transforms.ToTensor(),
            ]
        )

        self.drop_ratio = drop_ratio

    def augmentation(self, images, transform, state=None):
        # Restoring the RNG state keeps any random transforms consistent
        # across the video frames, pose frames, and reference image.
        if state is not None:
            torch.set_rng_state(state)
        if isinstance(images, List):
            transformed_images = [transform(img) for img in images]
            ret_tensor = torch.stack(transformed_images, dim=0)  # (f, c, h, w)
        else:
            ret_tensor = transform(images)  # (c, h, w)
        return ret_tensor

    def __getitem__(self, index):
        video_meta = self.vid_meta[index]
        video_path = video_meta["video_path"]
        kps_path = video_meta["kps_path"]
        video_reader = VideoReader(video_path)
        kps_reader = VideoReader(kps_path)

        assert len(video_reader) == len(
            kps_reader
        ), f"{len(video_reader) = } != {len(kps_reader) = } in {video_path}"

        video_length = len(video_reader)
        video_fps = video_reader.get_avg_fps()
        # Double the sampling stride for high-fps (30-60) videos so the
        # effective temporal coverage stays comparable.
        if video_fps > 30:
            sample_rate = self.sample_rate * 2
        else:
            sample_rate = self.sample_rate

        clip_length = min(
            video_length, (self.n_sample_frames - 1) * sample_rate + 1
        )
        start_idx = random.randint(0, video_length - clip_length)
        batch_index = np.linspace(
            start_idx, start_idx + clip_length - 1, self.n_sample_frames, dtype=int
        ).tolist()

        # Read the sampled RGB frames and the matching keypoint (pose) frames.
        vid_pil_image_list = []
        pose_pil_image_list = []
        for frame_idx in batch_index:
            img = video_reader[frame_idx]
            vid_pil_image_list.append(Image.fromarray(img.asnumpy()))
            img = kps_reader[frame_idx]
            pose_pil_image_list.append(Image.fromarray(img.asnumpy()))

        # Pick a random frame from the same video as the reference image.
        ref_img_idx = random.randint(0, video_length - 1)
        ref_img = Image.fromarray(video_reader[ref_img_idx].asnumpy())

        # Apply the transforms with a shared RNG state so all three outputs
        # see the same random augmentation.
        state = torch.get_rng_state()
        pixel_values_vid = self.augmentation(
            vid_pil_image_list, self.pixel_transform, state
        )
        pixel_values_pose = self.augmentation(
            pose_pil_image_list, self.cond_transform, state
        )
        pixel_values_ref_img = self.augmentation(ref_img, self.pixel_transform, state)
        clip_ref_img = self.clip_image_processor(
            images=ref_img, return_tensors="pt"
        ).pixel_values[0]

        sample = dict(
            video_dir=video_path,
            pixel_values_vid=pixel_values_vid,
            pixel_values_pose=pixel_values_pose,
            pixel_values_ref_img=pixel_values_ref_img,
            clip_ref_img=clip_ref_img,
        )

        return sample

    def __len__(self):
        return len(self.vid_meta)
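

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original class. It assumes a metadata
    # JSON at ./data/fashion_meta.json whose entries contain "video_path" and
    # "kps_path" keys (as read in __getitem__ above); the constructor arguments
    # below are illustrative values, not values prescribed by this file.
    from torch.utils.data import DataLoader

    dataset = HumanDanceVideoDataset(
        sample_rate=4,
        n_sample_frames=24,
        width=512,
        height=768,
        data_meta_paths=["./data/fashion_meta.json"],
    )
    loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)
    batch = next(iter(loader))
    # Expected shapes: pixel_values_vid / pixel_values_pose (b, f, c, h, w),
    # pixel_values_ref_img (b, c, h, w), clip_ref_img (b, 3, 224, 224) with the
    # default CLIPImageProcessor settings.
    print(batch["pixel_values_vid"].shape, batch["pixel_values_pose"].shape)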