# Data loading and collation utilities for video latent-diffusion training.
from functools import partial
import random
from typing import Any, Dict, List, Optional, Union

import numpy as np
import torch
from torch.utils.data import ConcatDataset, DataLoader

import datasets
from diffusers import DDPMScheduler
def collate_fn(
    batch: List[Dict[str, Any]],
    noise_scheduler: DDPMScheduler,
    num_frames: int,
    hint_spacing: Optional[int] = None,
    as_numpy: bool = True
) -> Dict[str, Union[torch.Tensor, np.ndarray]]:
    """Collate raw samples into a noised-latent batch for diffusion training.

    Each sample is expected to carry 'prompt' (a pre-computed text embedding)
    and 'video' (latent frames indexed as (f, c, h, w)) — assumed from the
    indexing/permute below; confirm against the dataset schema.

    Args:
        batch: list of sample dicts with 'prompt' and 'video' keys.
        noise_scheduler: scheduler providing num_train_timesteps and add_noise.
        num_frames: number of consecutive frames sliced from each video.
        hint_spacing: stride between hint frames; None or < 1 collapses to
            num_frames (a single hint frame repeated across the clip).
        as_numpy: if True, return float32/int32 numpy arrays; otherwise
            float32 torch tensors (timesteps stay int64).

    Returns:
        dict with 'latent_model_input', 'encoder_hidden_states',
        'timesteps' and 'noise'.

    Raises:
        AssertionError: if any video has fewer than num_frames frames.
    """
    if hint_spacing is None or hint_spacing < 1:
        hint_spacing = num_frames
    dtype = np.float32 if as_numpy else torch.float32
    prompts = []
    videos = []
    for sample in batch:
        # prompt embedding
        prompts.append(torch.tensor(sample['prompt']).to(dtype=torch.float32))
        # pick a random window of num_frames consecutive frames
        frames = torch.tensor(sample['video']).to(dtype=torch.float32)
        max_frames = len(frames)
        assert max_frames >= num_frames
        start = random.randint(0, max_frames - num_frames)
        frames = frames[start:start + num_frames]
        frames = frames.permute(1, 0, 2, 3)  # f, c, h, w -> c, f, h, w
        videos.append(frames)
    encoder_hidden_states = torch.cat(prompts)  # b, 77, 768
    latents = torch.stack(videos)  # b, c, f, h, w
    # 0.18215 is the usual SD VAE latent scaling factor — TODO confirm it
    # matches the VAE these latents came from.
    latents = latents * 0.18215
    # Take every hint_spacing-th frame as a hint and repeat each one
    # hint_spacing times, then trim to num_frames: when num_frames is not a
    # multiple of hint_spacing, repeat_interleave yields more than num_frames
    # frames and the torch.cat below would raise a size-mismatch error.
    hint_latents = latents[:, :, ::hint_spacing, :, :]
    hint_latents = hint_latents.repeat_interleave(hint_spacing, 2)
    hint_latents = hint_latents[:, :, :num_frames, :, :]
    input_latents = latents
    noise = torch.randn_like(input_latents)
    bsz = input_latents.shape[0]
    # one random diffusion timestep per batch element
    timesteps = torch.randint(
        0,
        noise_scheduler.config.num_train_timesteps,
        (bsz,),
        dtype=torch.int64
    )
    noisy_latents = noise_scheduler.add_noise(input_latents, noise, timesteps)
    # all-zero mask channel sandwiched between noisy latents and hints
    mask = torch.zeros([
        noisy_latents.shape[0],
        1,
        noisy_latents.shape[2],
        noisy_latents.shape[3],
        noisy_latents.shape[4]
    ])
    latent_model_input = torch.cat([noisy_latents, mask, hint_latents], dim=1)
    latent_model_input = latent_model_input.to(memory_format=torch.contiguous_format)
    encoder_hidden_states = encoder_hidden_states.to(memory_format=torch.contiguous_format)
    timesteps = timesteps.to(memory_format=torch.contiguous_format)
    noise = noise.to(memory_format=torch.contiguous_format)
    if as_numpy:
        latent_model_input = latent_model_input.numpy().astype(dtype)
        encoder_hidden_states = encoder_hidden_states.numpy().astype(dtype)
        timesteps = timesteps.numpy().astype(np.int32)
        noise = noise.numpy().astype(dtype)
    else:
        latent_model_input = latent_model_input.to(dtype=dtype)
        encoder_hidden_states = encoder_hidden_states.to(dtype=dtype)
        noise = noise.to(dtype=dtype)
    return {
        'latent_model_input': latent_model_input,
        'encoder_hidden_states': encoder_hidden_states,
        'timesteps': timesteps,
        'noise': noise
    }
def worker_init_fn(worker_id: int):
    """Seed Python's and NumPy's RNGs from this DataLoader worker's torch seed.

    torch assigns each worker a distinct 64-bit seed; random.seed and
    np.random.seed only accept values below 2**32, so fold it down first.
    """
    seed = torch.initial_seed() % 4294967294  # keep within the 32-bit range
    np.random.seed(seed)
    random.seed(seed)
def load_dataset(
    dataset_path: str,
    model_path: str,
    cache_dir: Optional[str] = None,
    batch_size: int = 1,
    num_frames: int = 24,
    hint_spacing: Optional[int] = None,
    num_workers: int = 0,
    shuffle: bool = False,
    as_numpy: bool = True,
    pin_memory: bool = False,
    pin_memory_device: str = ''
) -> DataLoader:
    """Build a DataLoader over every split of a Hugging Face dataset.

    Loads the DDPM noise scheduler from model_path and wires it, together
    with the framing options, into collate_fn. All splits of the dataset
    are concatenated into a single ConcatDataset.
    """
    scheduler: DDPMScheduler = DDPMScheduler.from_pretrained(
        model_path,
        subfolder='scheduler'
    )
    raw = datasets.load_dataset(
        dataset_path,
        streaming=False,
        cache_dir=cache_dir
    )
    merged = ConcatDataset([raw[split] for split in raw])
    batch_collator = partial(
        collate_fn,
        noise_scheduler=scheduler,
        num_frames=num_frames,
        hint_spacing=hint_spacing,
        as_numpy=as_numpy
    )
    return DataLoader(
        merged,
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=True,
        num_workers=num_workers,
        # keep workers alive between epochs only when there are workers
        persistent_workers=num_workers > 0,
        worker_init_fn=worker_init_fn,
        collate_fn=batch_collator,
        pin_memory=pin_memory,
        pin_memory_device=pin_memory_device
    )
def validate_dataset(
    dataset_path: str
) -> List[str]:
    """Check that every sample id has metadata, a prompt and a video.

    Scans <dataset_path>/data/{metadata,prompts,videos}, treats the filename
    stem as the sample id, writes the sorted list of complete ids to
    <dataset_path>/data/id_list.json, and returns the incomplete ids.

    Args:
        dataset_path: dataset root containing a 'data' directory.

    Returns:
        Sorted list of ids present in at least one of the three folders but
        missing from another. (The original annotated List[int] and returned
        None; the ids are filename stems, i.e. strings.)
    """
    import os
    import json

    data_path = os.path.join(dataset_path, 'data')

    def _stems(subdir: str) -> set:
        # filename stems of every entry in data/<subdir>
        return {
            os.path.splitext(x)[0]
            for x in os.listdir(os.path.join(data_path, subdir))
        }

    meta = _stems('metadata')
    prompts = _stems('prompts')
    videos = _stems('videos')
    ok = meta & prompts & videos
    not_ok = sorted((meta | prompts | videos) - ok)
    with open(os.path.join(data_path, 'id_list.json'), 'w') as f:
        json.dump(sorted(ok), f)
    return not_ok