wtx-mmlab
init commit
069c5f0
raw
history blame
No virus
3.6 kB
import os, io, csv, math, random
import numpy as np
from einops import rearrange
from decord import VideoReader
import torch
import torchvision.transforms as transforms
from torch.utils.data.dataset import Dataset
from animatediff.utils.util import zero_rank_print
class WebVid10M(Dataset):
def __init__(
self,
csv_path, video_folder,
sample_size=256, sample_stride=4, sample_n_frames=16,
is_image=False,
):
zero_rank_print(f"loading annotations from {csv_path} ...")
with open(csv_path, 'r') as csvfile:
self.dataset = list(csv.DictReader(csvfile))
self.length = len(self.dataset)
zero_rank_print(f"data scale: {self.length}")
self.video_folder = video_folder
self.sample_stride = sample_stride
self.sample_n_frames = sample_n_frames
self.is_image = is_image
sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
self.pixel_transforms = transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.Resize(sample_size[0]),
transforms.CenterCrop(sample_size),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
])
def get_batch(self, idx):
video_dict = self.dataset[idx]
videoid, name, page_dir = video_dict['videoid'], video_dict['name'], video_dict['page_dir']
video_dir = os.path.join(self.video_folder, f"{videoid}.mp4")
video_reader = VideoReader(video_dir)
video_length = len(video_reader)
if not self.is_image:
clip_length = min(video_length, (self.sample_n_frames - 1) * self.sample_stride + 1)
start_idx = random.randint(0, video_length - clip_length)
batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int)
else:
batch_index = [random.randint(0, video_length - 1)]
pixel_values = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous()
pixel_values = pixel_values / 255.
del video_reader
if self.is_image:
pixel_values = pixel_values[0]
return pixel_values, name
def __len__(self):
return self.length
def __getitem__(self, idx):
while True:
try:
pixel_values, name = self.get_batch(idx)
break
except Exception as e:
idx = random.randint(0, self.length-1)
pixel_values = self.pixel_transforms(pixel_values)
sample = dict(pixel_values=pixel_values, text=name)
return sample
if __name__ == "__main__":
from animatediff.utils.util import save_videos_grid
dataset = WebVid10M(
csv_path="/mnt/petrelfs/guoyuwei/projects/datasets/webvid/results_2M_val.csv",
video_folder="/mnt/petrelfs/guoyuwei/projects/datasets/webvid/2M_val",
sample_size=256,
sample_stride=4, sample_n_frames=16,
is_image=True,
)
import pdb
pdb.set_trace()
dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, num_workers=16,)
for idx, batch in enumerate(dataloader):
print(batch["pixel_values"].shape, len(batch["text"]))
# for i in range(batch["pixel_values"].shape[0]):
# save_videos_grid(batch["pixel_values"][i:i+1].permute(0,2,1,3,4), os.path.join(".", f"{idx}-{i}.mp4"), rescale=True)