# Author: lopho
# Last commit: b2f876f - "forgot about the nested package structure"
# (file size at that revision: 5.39 kB)
from typing import List, Dict, Any, Union, Optional
import torch
from torch.utils.data import DataLoader, ConcatDataset
import datasets
from diffusers import DDPMScheduler
from functools import partial
import random
import numpy as np
@torch.no_grad()
def collate_fn(
    batch: List[Dict[str, Any]],
    noise_scheduler: 'DDPMScheduler',
    num_frames: int,
    hint_spacing: Optional[int] = None,
    as_numpy: bool = True
) -> Dict[str, Union[torch.Tensor, np.ndarray]]:
    """Collate pre-encoded samples into a noised training batch with hint frames.

    Each sample is processed as follows:

    - A random window of ``num_frames`` consecutive frames is cut from
      ``s['video']``.
    - The prompt tensors ``s['prompt']`` are concatenated along dim 0.
    - The stacked latents are scaled by 0.18215.
    - The latents are noised at uniformly drawn timesteps via
      ``noise_scheduler.add_noise``.
    - The result is concatenated channel-wise with an all-zero mask channel
      and with "hint" latents. The hints are every ``hint_spacing``-th frame
      of the window, each repeated to fill the gap up to the next hint.

    Args:
        batch: samples with ``'prompt'`` and ``'video'`` entries. Each video
            is indexed as (frames, channels, height, width); see the permute
            below.
        noise_scheduler: object exposing ``config.num_train_timesteps`` and
            ``add_noise(latents, noise, timesteps)``.
        num_frames: length of the frame window taken from every video.
        hint_spacing: distance between hint frames. ``None`` or a value < 1
            means a single hint (the window's first frame) for all frames.
        as_numpy: if True, return numpy arrays (float32; timesteps int32).
            Otherwise return torch tensors (float32; timesteps int64).

    Returns:
        A dict with these keys:

        - ``'latent_model_input'``: (b, 2*c + 1, num_frames, h, w), i.e. noisy
          latents, then the mask channel, then the hint latents.
        - ``'encoder_hidden_states'``: the concatenated prompts.
        - ``'timesteps'``: shape (b,).
        - ``'noise'``: (b, c, num_frames, h, w).

    Raises:
        ValueError: if a video has fewer than ``num_frames`` frames.
    """
    if hint_spacing is None or hint_spacing < 1:
        hint_spacing = num_frames
    dtype = np.float32 if as_numpy else torch.float32

    prompts = []
    videos = []
    for s in batch:
        prompts.append(torch.as_tensor(s['prompt'], dtype=torch.float32))
        frames = torch.as_tensor(s['video'], dtype=torch.float32)
        max_frames = len(frames)
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if max_frames < num_frames:
            raise ValueError(
                f'video has {max_frames} frames, need at least {num_frames}'
            )
        start = random.randint(0, max_frames - num_frames)
        frames = frames[start:start + num_frames]
        videos.append(frames.permute(1, 0, 2, 3))  # f, c, h, w -> c, f, h, w

    # Presumably each prompt is (1, 77, 768), giving (b, 77, 768) - TODO confirm.
    encoder_hidden_states = torch.cat(prompts)
    # b, c, f, h, w. 0.18215 is presumably the Stable Diffusion VAE latent scale.
    latents = torch.stack(videos) * 0.18215

    # Keep every hint_spacing-th frame and repeat it to cover the frames up to
    # the next hint. If num_frames is not a multiple of hint_spacing, the
    # repeat overshoots (e.g. 24 frames, spacing 5 -> 25). Trim back to the
    # window length so the channel-wise concat below lines up.
    window = latents.shape[2]
    hint_latents = latents[:, :, ::hint_spacing]
    hint_latents = hint_latents.repeat_interleave(hint_spacing, dim=2)[:, :, :window]

    input_latents = latents
    noise = torch.randn_like(input_latents)
    bsz = input_latents.shape[0]
    timesteps = torch.randint(
        0,
        noise_scheduler.config.num_train_timesteps,
        (bsz,),
        dtype=torch.int64
    )
    noisy_latents = noise_scheduler.add_noise(input_latents, noise, timesteps)

    # One extra all-zero channel between the noisy latents and the hints.
    b, _, f, h, w = noisy_latents.shape
    mask = torch.zeros([b, 1, f, h, w])
    latent_model_input = torch.cat([noisy_latents, mask, hint_latents], dim=1)

    latent_model_input = latent_model_input.to(memory_format=torch.contiguous_format)
    encoder_hidden_states = encoder_hidden_states.to(memory_format=torch.contiguous_format)
    timesteps = timesteps.to(memory_format=torch.contiguous_format)
    noise = noise.to(memory_format=torch.contiguous_format)

    if as_numpy:
        latent_model_input = latent_model_input.numpy().astype(dtype)
        encoder_hidden_states = encoder_hidden_states.numpy().astype(dtype)
        timesteps = timesteps.numpy().astype(np.int32)
        noise = noise.numpy().astype(dtype)
    else:
        latent_model_input = latent_model_input.to(dtype=dtype)
        encoder_hidden_states = encoder_hidden_states.to(dtype=dtype)
        noise = noise.to(dtype=dtype)
    return {
        'latent_model_input': latent_model_input,
        'encoder_hidden_states': encoder_hidden_states,
        'timesteps': timesteps,
        'noise': noise
    }
def worker_init_fn(worker_id: int):
    """Seed Python's ``random`` and NumPy's global RNG from the torch seed.

    Intended as a DataLoader ``worker_init_fn``. The seed is derived from
    ``torch.initial_seed()``, which PyTorch sets per worker, so each worker
    gets distinct ``random``/NumPy streams. ``worker_id`` is accepted to
    match the DataLoader callback signature but is not used.
    """
    # np.random.seed needs a value < 2**32, so fold the torch seed into range.
    # The modulus is 2**32 - 2, so seeds fall in [0, 2**32 - 3].
    derived_seed = torch.initial_seed() % (2 ** 32 - 2)
    for seed_rng in (random.seed, np.random.seed):
        seed_rng(derived_seed)
def load_dataset(
    dataset_path: str,
    model_path: str,
    cache_dir: Optional[str] = None,
    batch_size: int = 1,
    num_frames: int = 24,
    hint_spacing: Optional[int] = None,
    num_workers: int = 0,
    shuffle: bool = False,
    as_numpy: bool = True,
    pin_memory: bool = False,
    pin_memory_device: str = ''
) -> DataLoader:
    """Build a training DataLoader over every split of a Hugging Face dataset.

    The noise scheduler is loaded from the ``scheduler`` subfolder of
    ``model_path``. All splits returned by ``datasets.load_dataset`` are
    concatenated into one dataset. Batches are produced by ``collate_fn``,
    with the scheduler and frame/hint settings bound in. Incomplete final
    batches are dropped. Workers are kept alive between epochs whenever
    ``num_workers > 0``.
    """
    scheduler = DDPMScheduler.from_pretrained(model_path, subfolder='scheduler')
    splits = datasets.load_dataset(
        dataset_path,
        streaming=False,
        cache_dir=cache_dir,
    )
    # Train on the union of all splits the dataset provides.
    merged = ConcatDataset([splits[split_name] for split_name in splits])

    batch_builder = partial(
        collate_fn,
        noise_scheduler=scheduler,
        num_frames=num_frames,
        hint_spacing=hint_spacing,
        as_numpy=as_numpy,
    )
    keep_workers_alive = num_workers > 0

    return DataLoader(
        merged,
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=True,
        num_workers=num_workers,
        persistent_workers=keep_workers_alive,
        worker_init_fn=worker_init_fn,
        collate_fn=batch_builder,
        pin_memory=pin_memory,
        pin_memory_device=pin_memory_device,
    )
def validate_dataset(
    dataset_path: str
) -> List[str]:
    """Write the IDs of complete samples to ``<dataset_path>/data/id_list.json``.

    A sample ID is a file name with its extension stripped. An ID is complete
    when it appears in all three of ``data/metadata``, ``data/prompts`` and
    ``data/videos``. The complete IDs are written, sorted, as a JSON list.

    Args:
        dataset_path: dataset root that contains the ``data`` directory.

    Returns:
        Sorted list of incomplete IDs: those present in at least one of the
        three directories but missing from another.
    """
    import os
    import json

    data_path = os.path.join(dataset_path, 'data')

    def _stems(subdir: str) -> set:
        # File names in data/<subdir> with their extension stripped.
        return {
            os.path.splitext(name)[0]
            for name in os.listdir(os.path.join(data_path, subdir))
        }

    meta = _stems('metadata')
    prompts = _stems('prompts')
    videos = _stems('videos')

    ok = meta & prompts & videos
    not_ok = sorted((meta | prompts | videos) - ok)

    with open(os.path.join(data_path, 'id_list.json'), 'w', encoding='utf-8') as f:
        json.dump(sorted(ok), f)
    return not_ok