import os
import glob
import random
import warnings

import numpy as np
import torch
import torchvision
from torch.utils.data import Dataset
from decord import VideoReader
from pcache_fileio import fileio  # noqa: F401  (import installs the pcache open() hook)
from pcache_fileio.oss_conf import OssConfigFactory  # noqa: F401


class SakugaRefDataset(Dataset):
    """Video dataset that pairs a reference frame with a clip sampled after it."""

    def __init__(
        self,
        # width=1024, height=576,
        video_frames=25,
        ref_jump_frames=36,
        base_folder='data/samples/',
        file_list=None,
        temporal_sample=None,
        transform=None,
        seed=42,
    ):
        """
        Args:
            video_frames (int): Number of clip frames to sample per video.
            ref_jump_frames (int): Minimum gap between the reference frame and
                the first clip frame.
            base_folder (str): Folder containing the .mp4 files.
            file_list (str, optional): Path to a text file listing videos, one
                relative path per line. When given, videos are read through
                pcache instead of globbed locally.
            temporal_sample (callable): Maps a total frame count to a
                (start, end) frame window.
            transform (callable): Applied to the stacked (F, C, H, W) tensor.
            seed (int): Random seed (stored for reproducibility).
        """
        self.base_folder = base_folder
        self.file_list = file_list
        if file_list is None:
            self.video_lists = glob.glob(os.path.join(self.base_folder, '*.mp4'))
        else:
            # Read relative video paths from the file list.
            self.video_lists = []
            with open(file_list, 'r') as f:
                for line in f:
                    video_path = line.strip()
                    self.video_lists.append(os.path.join(self.base_folder, video_path))
        self.num_samples = len(self.video_lists)
        self.channels = 3
        # self.width = width
        # self.height = height
        self.video_frames = video_frames
        self.ref_jump_frames = ref_jump_frames
        self.temporal_sample = temporal_sample
        self.transform = transform
        self.seed = seed

    def __len__(self):
        return self.num_samples

    def get_sample(self, idx):
        """
        Args:
            idx (int): Index of the sample to return.

        Returns:
            dict: {'pixel_values': tensor of shape (video_frames + 1, channels, H, W)}.
                Index 0 is the reference image; the remaining indices are the
                video frames.
        """
        path = self.video_lists[idx]
        if self.file_list is not None:
            # Read through pcache with decord.
            with open(path, 'rb') as f:
                vframes = VideoReader(f)
        else:
            vframes, _, _ = torchvision.io.read_video(filename=path, pts_unit='sec', output_format='TCHW')
        total_frames = len(vframes)

        # Sample a temporal window and make sure it can hold the reference
        # jump plus the clip.
        ref_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
        if end_frame_ind - ref_frame_ind < self.video_frames + self.ref_jump_frames:
            raise ValueError(f'video {path} does not have enough frames')
        start_frame_ind = ref_frame_ind + self.ref_jump_frames
        frame_indices = np.linspace(start_frame_ind, end_frame_ind - 1, self.video_frames, dtype=int)
        # Prepend the reference frame so pixel_values[0] is the reference image.
        frame_indices = np.insert(frame_indices, 0, ref_frame_ind)

        if self.file_list is not None:
            # decord returns (F, H, W, C); convert to (F, C, H, W).
            video = torch.from_numpy(vframes.get_batch(frame_indices).asnumpy()).permute(0, 3, 1, 2).contiguous()
        else:
            video = vframes[frame_indices]  # (F, C, H, W)
        pixel_values = self.transform(video)
        return {'pixel_values': pixel_values}

    def __getitem__(self, idx):
        # Retry with a random index on decode or sampling failures so a single
        # corrupt or too-short video does not kill the whole epoch.
        while True:
            try:
                return self.get_sample(idx)
            except Exception:
                warnings.warn(f'loading sample {idx} failed, retrying with a random index...')
                idx = np.random.randint(0, len(self.video_lists))
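

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the module's public API): shows how
# the dataset might be wired up with a temporal sampler and a transform.
# `TemporalRandomCrop` below is a hypothetical stand-in for whatever
# `temporal_sample` callable the training code actually passes in; the only
# contract the dataset relies on is that it maps a frame count to a
# (start, end) window wide enough for ref_jump_frames + video_frames. The
# resolution and normalization values are assumptions, not project settings.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import torchvision.transforms as T

    class TemporalRandomCrop:
        """Hypothetical sampler: returns a random (start, end) frame window."""

        def __init__(self, size):
            self.size = size

        def __call__(self, total_frames):
            start = random.randint(0, max(0, total_frames - self.size))
            return start, min(start + self.size, total_frames)

    video_frames, ref_jump_frames = 25, 36
    dataset = SakugaRefDataset(
        video_frames=video_frames,
        ref_jump_frames=ref_jump_frames,
        base_folder='data/samples/',
        # The window must cover the reference jump plus the sampled clip.
        temporal_sample=TemporalRandomCrop(video_frames + ref_jump_frames + 1),
        transform=T.Compose([
            T.Resize((320, 512), antialias=True),
            T.ConvertImageDtype(torch.float32),   # uint8 [0, 255] -> float [0, 1]
            T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),  # -> [-1, 1]
        ]),
    )
    sample = dataset[0]
    # Expected: (video_frames + 1, 3, 320, 512); index 0 is the reference frame.
    print(sample['pixel_values'].shape)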