Spaces:
Sleeping
Sleeping
import os, io, csv, math, random | |
import numpy as np | |
from einops import rearrange | |
from decord import VideoReader | |
import torch | |
import torchvision.transforms as transforms | |
from torch.utils.data.dataset import Dataset | |
from animatediff.utils.util import zero_rank_print | |
class WebVid10M(Dataset):
    """Video/caption dataset backed by a WebVid-style annotation CSV.

    Each CSV row must provide a ``videoid`` column (the clip is read from
    ``<video_folder>/<videoid>.mp4``) and a ``name`` column, which is returned
    as the sample's ``text``.

    Samples are dicts ``{"pixel_values": Tensor, "text": str}``. In video mode
    ``pixel_values`` is ``(sample_n_frames, C, H, W)``; in image mode it is a
    single frame ``(C, H, W)``. Both layouts assume decord returns frames as
    ``(N, H, W, C)``. Pixels are scaled by 1/255 and then normalized with
    mean=std=0.5, i.e. roughly into [-1, 1] for uint8 input.

    Args:
        csv_path: Path to the annotation CSV (read fully into memory).
        video_folder: Directory containing the ``<videoid>.mp4`` files.
        sample_size: Output spatial size. An int ``s`` means ``(s, s)``;
            otherwise a 2-item sequence ``(H, W)``. Frames are resized so the
            shorter edge equals ``sample_size[0]`` and are then center-cropped.
        sample_stride: Nominal frame stride between sampled frames.
        sample_n_frames: Number of frames per clip (video mode only).
        is_image: If True, sample one random frame instead of a clip.
        max_load_attempts: Maximum number of videos ``__getitem__`` tries
            (the requested one plus random replacements) before raising
            ``RuntimeError``. ``None`` retries forever.

    Raises:
        ValueError: If ``max_load_attempts`` is neither None nor >= 1.
    """

    def __init__(
        self,
        csv_path, video_folder,
        sample_size=256, sample_stride=4, sample_n_frames=16,
        is_image=False,
        max_load_attempts=100,
    ):
        if max_load_attempts is not None and max_load_attempts < 1:
            raise ValueError(
                f"max_load_attempts must be None or >= 1, got {max_load_attempts!r}"
            )

        zero_rank_print(f"loading annotations from {csv_path} ...")
        # The csv module requires newline='' so that quoted fields containing
        # newlines (possible in free-text captions) are parsed as one field.
        with open(csv_path, 'r', newline='', encoding='utf-8') as csvfile:
            self.dataset = list(csv.DictReader(csvfile))
        self.length = len(self.dataset)
        zero_rank_print(f"data scale: {self.length}")

        self.video_folder = video_folder
        self.sample_stride = sample_stride
        self.sample_n_frames = sample_n_frames
        self.is_image = is_image
        self.max_load_attempts = max_load_attempts

        sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
        self.pixel_transforms = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.Resize(sample_size[0]),
            transforms.CenterCrop(sample_size),
            # inplace is safe: get_batch always returns a freshly allocated tensor.
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ])

    def get_batch(self, idx):
        """Decode the frames for row ``idx``.

        Returns:
            ``(pixel_values, name)``. ``pixel_values`` is a float tensor in
            [0, 1] (for uint8 video), shaped ``(N, C, H, W)`` in video mode or
            ``(C, H, W)`` in image mode. ``name`` is the row's caption.

        Raises:
            Whatever the CSV lookup or decord raises for a missing, corrupt or
            empty video. ``__getitem__`` handles these by resampling.
        """
        video_dict = self.dataset[idx]
        videoid, name = video_dict['videoid'], video_dict['name']
        video_path = os.path.join(self.video_folder, f"{videoid}.mp4")

        video_reader = VideoReader(video_path)
        try:
            video_length = len(video_reader)
            if not self.is_image:
                # Span the clip would cover at the nominal stride, clamped to the
                # video length. linspace then picks sample_n_frames evenly spaced
                # indices inside it, repeating frames when the video is short.
                clip_length = min(video_length, (self.sample_n_frames - 1) * self.sample_stride + 1)
                start_idx = random.randint(0, video_length - clip_length)
                batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int)
            else:
                batch_index = [random.randint(0, video_length - 1)]
            frames = video_reader.get_batch(batch_index).asnumpy()
        finally:
            # Drop the reader (and its file handle) as soon as decoding ends,
            # including when decoding fails.
            del video_reader

        pixel_values = torch.from_numpy(frames).permute(0, 3, 1, 2).contiguous()
        pixel_values = pixel_values / 255.
        if self.is_image:
            pixel_values = pixel_values[0]
        return pixel_values, name

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        """Return ``{"pixel_values": ..., "text": ...}`` for ``idx``.

        If the requested video cannot be loaded, a random other row is tried
        instead, up to ``max_load_attempts`` total attempts. This is
        best-effort skipping of corrupt or missing files.

        Raises:
            RuntimeError: When every attempt failed. It is chained to the last
                underlying load error.
        """
        last_error = None
        attempts = 0
        while self.max_load_attempts is None or attempts < self.max_load_attempts:
            attempts += 1
            try:
                pixel_values, name = self.get_batch(idx)
            except Exception as err:  # broad on purpose: any bad video is skipped
                last_error = err
                idx = random.randint(0, self.length - 1)
            else:
                break
        else:
            raise RuntimeError(
                f"could not load any video after {attempts} attempt(s); "
                f"last error: {last_error!r}"
            ) from last_error

        pixel_values = self.pixel_transforms(pixel_values)
        sample = dict(pixel_values=pixel_values, text=name)
        return sample
if __name__ == "__main__":
    # Smoke test: build the dataset, iterate a DataLoader over it and print
    # batch shapes. Optionally dump sampled clips to disk for visual checks.
    # Use `python -m pdb` when an interactive debugger is wanted.
    import argparse

    from animatediff.utils.util import save_videos_grid

    parser = argparse.ArgumentParser(description="Iterate WebVid10M and print batch shapes.")
    parser.add_argument("--csv_path", default="/mnt/petrelfs/guoyuwei/projects/datasets/webvid/results_2M_val.csv")
    parser.add_argument("--video_folder", default="/mnt/petrelfs/guoyuwei/projects/datasets/webvid/2M_val")
    parser.add_argument("--sample_size", type=int, default=256)
    parser.add_argument("--sample_stride", type=int, default=4)
    parser.add_argument("--sample_n_frames", type=int, default=16)
    parser.add_argument("--video", action="store_true",
                        help="sample multi-frame clips instead of single frames (default: single frames)")
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--num_workers", type=int, default=16)
    parser.add_argument("--save_dir", default=None,
                        help="with --video, write each sampled clip to <save_dir>/<batch>-<i>.mp4")
    args = parser.parse_args()

    dataset = WebVid10M(
        csv_path=args.csv_path,
        video_folder=args.video_folder,
        sample_size=args.sample_size,
        sample_stride=args.sample_stride, sample_n_frames=args.sample_n_frames,
        is_image=not args.video,
    )
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, num_workers=args.num_workers)

    save_clips = args.save_dir is not None and args.video
    if save_clips:
        os.makedirs(args.save_dir, exist_ok=True)

    for idx, batch in enumerate(dataloader):
        print(batch["pixel_values"].shape, len(batch["text"]))
        if save_clips:
            for i in range(batch["pixel_values"].shape[0]):
                # Swap the frame and channel axes: (b, f, c, h, w) -> (b, c, f, h, w).
                # Presumably the layout save_videos_grid expects; confirm against its definition.
                save_videos_grid(
                    batch["pixel_values"][i:i + 1].permute(0, 2, 1, 3, 4),
                    os.path.join(args.save_dir, f"{idx}-{i}.mp4"),
                    rescale=True,
                )