# AniDoc/datasets/sakuga_ref_datasets.py
import glob
import os
import random
import warnings

import numpy as np
import torch
import torch.distributed as dist
import torchvision
from decord import VideoReader
from PIL import Image
from torch.utils.data import Dataset

# Imported for their side effects: pcache_fileio enables the built-in open()
# on pcache/OSS paths, which is how clips are read when file_list is given.
from pcache_fileio import fileio  # noqa: F401
from pcache_fileio.oss_conf import OssConfigFactory  # noqa: F401


class SakugaRefDataset(Dataset):
def __init__(
self,
video_frames=25,
ref_jump_frames=36,
base_folder='data/samples/',
file_list=None,
temporal_sample=None,
transform=None,
seed=42,
):
"""
Args:
num_samples (int): Number of samples in the dataset.
channels (int): Number of channels, default is 3 for RGB.
"""
# Define the path to the folder containing video frames
self.base_folder = base_folder
self.file_list = file_list
if file_list is None:
self.video_lists = glob.glob(os.path.join(self.base_folder, '*.mp4'))
else:
# read from file_list.txt
self.video_lists = []
with open(file_list, 'r') as f:
for line in f:
video_path = line.strip()
self.video_lists.append(os.path.join(self.base_folder, video_path))
self.num_samples = len(self.video_lists)
self.channels = 3
self.video_frames = video_frames
self.ref_jump_frames = ref_jump_frames
self.temporal_sample = temporal_sample
self.transform = transform
self.seed = seed

    def __len__(self):
        return self.num_samples

def get_sample(self, idx):
"""
Args:
idx (int): Index of the sample to return.
Returns:
dict: A dictionary containing the 'pixel_values' tensor of shape (16, channels, 320, 512).
"""
path = self.video_lists[idx]
if self.file_list is not None: # read from pcache
with open(path, 'rb') as f:
vframes = VideoReader(f)
else:
vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit='sec', output_format='TCHW')
total_frames = len(vframes)
        # Sample a reference frame and the window of video frames
        ref_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
        if end_frame_ind - ref_frame_ind < self.video_frames + self.ref_jump_frames:
            raise ValueError(f'video {path} does not have enough frames')
        # Evenly sample video_frames indices after the reference-jump gap,
        # then prepend the reference frame index at position 0.
        start_frame_ind = ref_frame_ind + self.ref_jump_frames
        frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, self.video_frames, dtype=int)
        frame_indice = np.insert(frame_indice, 0, ref_frame_ind)
        if self.file_list is not None:  # read from pcache
            # decord returns frames as (T, H, W, C); permute to (T, C, H, W)
            video = torch.from_numpy(vframes.get_batch(frame_indice).asnumpy()).permute(0, 3, 1, 2).contiguous()
        else:
            video = vframes[frame_indice]
        # (f c h w)
        pixel_values = self.transform(video)
        # index 0 of pixel_values is the reference frame; the remaining
        # indices are the sampled video frames
        return {'pixel_values': pixel_values}

    def __getitem__(self, idx):
        # Retry with a random index whenever a clip fails to load or is too
        # short; get_sample raises in both cases.
        while True:
            try:
                return self.get_sample(idx)
            except Exception:
                warnings.warn(f'loading sample {idx} failed, retrying with a random index...')
                idx = np.random.randint(0, len(self.video_lists))