import os

import decord
import numpy as np
import torch
from einops import rearrange
from PIL import Image
from torch.utils.data import Dataset

# Have decord return torch tensors instead of its native NDArray type.
decord.bridge.set_bridge('torch')


class TuneAVideoDataset(Dataset):
    def __init__(
        self,
        video_path: str,
        prompt: str,
        width: int = 512,
        height: int = 512,
        n_sample_frames: int = 8,
        sample_start_idx: int = 0,
        sample_frame_rate: int = 1,
    ):
        self.video_path = video_path
        self.prompt = prompt
        # Token ids are filled in externally (by the training script) once a
        # tokenizer is available.
        self.prompt_ids = None
        self.uncond_prompt_ids = None

        self.width = width
        self.height = height
        self.n_sample_frames = n_sample_frames
        self.sample_start_idx = sample_start_idx
        self.sample_frame_rate = sample_frame_rate

        # If the path is not an mp4 file, treat it as a directory of jpg
        # frames named by frame index (e.g. 0.jpg, 1.jpg, ...). Filter to jpg
        # files before sorting so stray files do not break the numeric sort.
        if 'mp4' not in self.video_path:
            frame_files = sorted(
                (f for f in os.listdir(self.video_path) if f.endswith('.jpg')),
                key=lambda x: int(x[:-4]))
            self.images = np.stack([
                np.asarray(
                    Image.open(os.path.join(self.video_path,
                                            f)).convert('RGB').resize(
                                                (self.width, self.height)))
                for f in frame_files
            ])

    def __len__(self):
        return 1

    def __getitem__(self, index):
        # Load and sample video frames.
        if 'mp4' in self.video_path:
            vr = decord.VideoReader(self.video_path,
                                    width=self.width,
                                    height=self.height)
            sample_index = list(
                range(self.sample_start_idx, len(vr),
                      self.sample_frame_rate))[:self.n_sample_frames]
            video = vr.get_batch(sample_index)  # (f, h, w, c) uint8 tensor
        else:
            # Apply the same frame sampling as the mp4 branch, then convert to
            # a torch tensor to match the decord torch-bridge output.
            sample_index = list(
                range(self.sample_start_idx, len(self.images),
                      self.sample_frame_rate))[:self.n_sample_frames]
            video = torch.from_numpy(self.images[sample_index])

        video = rearrange(video, 'f h w c -> f c h w')

        example = {
            # Scale uint8 pixel values from [0, 255] to [-1, 1].
            'pixel_values': (video / 127.5 - 1.0),
            'prompt_ids': self.prompt_ids,
        }
        return example
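

# Usage sketch (an assumption, not part of the original file): the training
# script is expected to tokenize the prompt and assign the result to
# `prompt_ids` (and, if classifier-free guidance is trained, the empty string
# to `uncond_prompt_ids`) before the dataset is consumed. The video path,
# prompt, and Stable Diffusion checkpoint below are hypothetical examples.
if __name__ == '__main__':
    from transformers import CLIPTokenizer

    tokenizer = CLIPTokenizer.from_pretrained(
        'runwayml/stable-diffusion-v1-5', subfolder='tokenizer')

    dataset = TuneAVideoDataset(video_path='data/man-skiing.mp4',
                                prompt='a man is skiing')
    dataset.prompt_ids = tokenizer(
        dataset.prompt,
        max_length=tokenizer.model_max_length,
        padding='max_length',
        truncation=True,
        return_tensors='pt').input_ids[0]
    dataset.uncond_prompt_ids = tokenizer(
        '',
        max_length=tokenizer.model_max_length,
        padding='max_length',
        truncation=True,
        return_tensors='pt').input_ids[0]

    example = dataset[0]
    # pixel_values: (n_sample_frames, 3, height, width), values in [-1, 1]
    print(example['pixel_values'].shape)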