import math
from contextlib import nullcontext
from functools import partial
from typing import Dict, List, Optional, Tuple, Union

import kornia
import numpy as np
import open_clip
import torch
import torch.nn as nn
from einops import rearrange, repeat
from omegaconf import ListConfig
from torch.utils.checkpoint import checkpoint
from transformers import (
    ByT5Tokenizer,
    CLIPTextModel,
    CLIPTokenizer,
    T5EncoderModel,
    T5Tokenizer,
)

from ...modules.autoencoding.regularizers import DiagonalGaussianRegularizer
from ...modules.diffusionmodules.model import Encoder
from ...modules.diffusionmodules.openaimodel import Timestep
from ...modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule
from ...modules.distributions.distributions import DiagonalGaussianDistribution
from ...util import (
    append_dims,
    autocast,
    count_params,
    default,
    disabled_train,
    expand_dims_like,
    instantiate_from_config,
)


class AbstractEmbModel(nn.Module):
    """Base class for embedders managed by ``GeneralConditioner``.

    Holds three pieces of per-embedder configuration, each exposed as a
    property with getter, setter and deleter: ``is_trainable``, ``ucg_rate``
    and ``input_key``. All three start out as ``None``.
    """

    def __init__(self):
        super().__init__()
        self._is_trainable = None
        self._ucg_rate = None
        self._input_key = None

    @property
    def is_trainable(self) -> bool:
        return self._is_trainable

    @property
    def ucg_rate(self) -> Union[float, torch.Tensor]:
        return self._ucg_rate

    @property
    def input_key(self) -> str:
        return self._input_key

    @is_trainable.setter
    def is_trainable(self, value: bool):
        self._is_trainable = value

    @ucg_rate.setter
    def ucg_rate(self, value: Union[float, torch.Tensor]):
        self._ucg_rate = value

    @input_key.setter
    def input_key(self, value: str):
        self._input_key = value

    @is_trainable.deleter
    def is_trainable(self):
        del self._is_trainable

    @ucg_rate.deleter
    def ucg_rate(self):
        del self._ucg_rate

    @input_key.deleter
    def input_key(self):
        del self._input_key


class GeneralConditioner(nn.Module):
    """Runs a list of embedders over a batch and merges their outputs.

    Each embedder output is routed to an output key chosen by its number of
    dimensions (``OUTPUT_DIM2KEYS``); outputs that share a key are
    concatenated along the axis given by ``KEY2CATDIM``.
    """

    OUTPUT_DIM2KEYS = {2: "vector", 3: "crossattn", 4: "concat", 5: "concat"}
    KEY2CATDIM = {"vector": 1, "crossattn": 2, "concat": 1}

    def __init__(self, emb_models: Union[List, ListConfig]):
        """Instantiate and configure every embedder described in ``emb_models``.

        Each config entry may carry ``is_trainable`` (default ``False``),
        ``ucg_rate`` (default ``0.0``), ``legacy_ucg_value`` (default ``None``)
        and must carry either ``input_key`` or ``input_keys``.

        Raises:
            KeyError: if a config has neither ``input_key`` nor ``input_keys``.
        """
        super().__init__()
        embedders = []
        for n, embconfig in enumerate(emb_models):
            embedder = instantiate_from_config(embconfig)
            assert isinstance(
                embedder, AbstractEmbModel
            ), f"embedder model {embedder.__class__.__name__} has to inherit from AbstractEmbModel"
            embedder.is_trainable = embconfig.get("is_trainable", False)
            embedder.ucg_rate = embconfig.get("ucg_rate", 0.0)
            if not embedder.is_trainable:
                # Freeze: override train(), disable gradients, switch to eval.
                embedder.train = disabled_train
                for param in embedder.parameters():
                    param.requires_grad = False
                embedder.eval()
            print(
                f"Initialized embedder #{n}: {embedder.__class__.__name__} "
                f"with {count_params(embedder, False)} params. Trainable: {embedder.is_trainable}"
            )

            if "input_key" in embconfig:
                embedder.input_key = embconfig["input_key"]
            elif "input_keys" in embconfig:
                embedder.input_keys = embconfig["input_keys"]
            else:
                raise KeyError(
                    f"need either 'input_key' or 'input_keys' for embedder {embedder.__class__.__name__}"
                )

            embedder.legacy_ucg_val = embconfig.get("legacy_ucg_value", None)
            if embedder.legacy_ucg_val is not None:
                embedder.ucg_prng = np.random.RandomState()

            embedders.append(embedder)
        self.embedders = nn.ModuleList(embedders)

    def possibly_get_ucg_val(self, embedder: AbstractEmbModel, batch: Dict) -> Dict:
        """Replace entries of ``batch[embedder.input_key]`` by the legacy UCG value.

        Every entry is replaced independently with probability
        ``embedder.ucg_rate``. ``batch`` is modified in place and returned.
        """
        assert embedder.legacy_ucg_val is not None
        p = embedder.ucg_rate
        val = embedder.legacy_ucg_val
        for i in range(len(batch[embedder.input_key])):
            if embedder.ucg_prng.choice(2, p=[1 - p, p]):
                batch[embedder.input_key][i] = val
        return batch

    def forward(
        self, batch: Dict, force_zero_embeddings: Optional[List] = None
    ) -> Dict:
        """Embed ``batch`` with every embedder and collect the results.

        Args:
            batch: mapping from input keys to the values the embedders consume.
            force_zero_embeddings: input keys whose embeddings are replaced by
                zeros of the same shape.

        Returns:
            Dict keyed by the entries of ``OUTPUT_DIM2KEYS`` that were produced.
        """
        output = dict()
        if force_zero_embeddings is None:
            force_zero_embeddings = []
        for embedder in self.embedders:
            # Frozen embedders are evaluated without autograd bookkeeping.
            embedding_context = nullcontext if embedder.is_trainable else torch.no_grad
            with embedding_context():
                if hasattr(embedder, "input_key") and (embedder.input_key is not None):
                    if embedder.legacy_ucg_val is not None:
                        batch = self.possibly_get_ucg_val(embedder, batch)
                    emb_out = embedder(batch[embedder.input_key])
                elif hasattr(embedder, "input_keys"):
                    emb_out = embedder(*[batch[k] for k in embedder.input_keys])
            assert isinstance(
                emb_out, (torch.Tensor, list, tuple)
            ), f"encoder outputs must be tensors or a sequence, but got {type(emb_out)}"
            if not isinstance(emb_out, (list, tuple)):
                emb_out = [emb_out]
            for emb in emb_out:
                out_key = self.OUTPUT_DIM2KEYS[emb.dim()]
                if embedder.ucg_rate > 0.0 and embedder.legacy_ucg_val is None:
                    # Per-sample Bernoulli keep-mask (keep prob = 1 - ucg_rate),
                    # expanded to the embedding's rank.
                    emb = (
                        expand_dims_like(
                            torch.bernoulli(
                                (1.0 - embedder.ucg_rate)
                                * torch.ones(emb.shape[0], device=emb.device)
                            ),
                            emb,
                        )
                        * emb
                    )
                if (
                    hasattr(embedder, "input_key")
                    and embedder.input_key in force_zero_embeddings
                ):
                    emb = torch.zeros_like(emb)
                if out_key in output:
                    output[out_key] = torch.cat(
                        (output[out_key], emb), self.KEY2CATDIM[out_key]
                    )
                else:
                    output[out_key] = emb
        return output

    def get_unconditional_conditioning(
        self,
        batch_c: Dict,
        batch_uc: Optional[Dict] = None,
        force_uc_zero_embeddings: Optional[List[str]] = None,
        force_cond_zero_embeddings: Optional[List[str]] = None,
    ):
        """Compute conditional and unconditional conditioning with dropout disabled.

        Every embedder's ``ucg_rate`` is temporarily set to ``0.0`` for both
        passes and restored afterwards, even if one of the passes raises.

        Args:
            batch_c: batch used for the conditional pass.
            batch_uc: batch used for the unconditional pass; ``batch_c`` is
                reused when this is ``None``.
            force_uc_zero_embeddings: keys zeroed in the unconditional pass.
            force_cond_zero_embeddings: keys zeroed in the conditional pass.

        Returns:
            Tuple ``(c, uc)`` of conditioning dicts.
        """
        if force_uc_zero_embeddings is None:
            force_uc_zero_embeddings = []
        ucg_rates = list()
        for embedder in self.embedders:
            ucg_rates.append(embedder.ucg_rate)
            embedder.ucg_rate = 0.0
        try:
            c = self(batch_c, force_cond_zero_embeddings)
            uc = self(
                batch_c if batch_uc is None else batch_uc, force_uc_zero_embeddings
            )
        finally:
            # Always restore the configured rates so a failed pass cannot leave
            # dropout silently disabled for subsequent training steps.
            for embedder, rate in zip(self.embedders, ucg_rates):
                embedder.ucg_rate = rate
        return c, uc


class InceptionV3(nn.Module):
    """Wrapper around the https://github.com/mseitzer/pytorch-fid inception
    port with an additional squeeze at the end"""

    def __init__(self, normalize_input=False, **kwargs):
        super().__init__()
        from pytorch_fid import inception

        kwargs["resize_input"] = True
        self.model = inception.InceptionV3(normalize_input=normalize_input, **kwargs)

    def forward(self, inp):
        """Run the wrapped model; a single-element output is unwrapped and squeezed."""
        outp = self.model(inp)

        if len(outp) == 1:
            return outp[0].squeeze()

        return outp


class IdentityEncoder(AbstractEmbModel):
    """Embedder that returns its input unchanged."""

    def encode(self, x):
        return x

    def forward(self, x):
        return x


class ClassEmbedder(AbstractEmbModel):
    """Embeds integer class labels with an ``nn.Embedding`` table."""

    def __init__(self, embed_dim, n_classes=1000, add_sequence_dim=False):
        super().__init__()
        self.embedding = nn.Embedding(n_classes, embed_dim)
        self.n_classes = n_classes
        self.add_sequence_dim = add_sequence_dim

    def _default_key(self):
        """Return the batch key of this embedder.

        An explicitly assigned ``key`` attribute wins if present; otherwise the
        ``input_key`` maintained by ``AbstractEmbModel`` is used.
        """
        key = getattr(self, "key", None)
        return self.input_key if key is None else key

    def forward(self, c):
        """Look up embeddings for ``c``; optionally insert a singleton axis at dim 1."""
        c = self.embedding(c)
        if self.add_sequence_dim:
            c = c[:, None, :]
        return c

    def get_unconditional_conditioning(self, bs, device="cuda"):
        """Return ``{key: labels}`` with ``bs`` copies of the last class index."""
        # The last class index (n_classes - 1) is used as the unconditional class.
        uc_class = self.n_classes - 1
        uc = torch.ones((bs,), device=device) * uc_class
        # Fix: the original read ``self.key``, which is never defined on this class.
        uc = {self._default_key(): uc.long()}
        return uc


class ClassEmbedderForMultiCond(ClassEmbedder):
    """``ClassEmbedder`` variant that reads from and writes back into a batch dict."""

    def forward(self, batch, key=None, disable_dropout=False):
        """Embed ``batch[key]`` and store the result back under the same key.

        If ``batch[key]`` is a list, its first element is embedded and the
        result is stored as a one-element list. ``batch`` is modified in place
        and returned. ``disable_dropout`` is accepted for interface
        compatibility and is not used.
        """
        out = batch
        key = default(key, self._default_key())
        islist = isinstance(batch[key], list)
        if islist:
            batch[key] = batch[key][0]
        # Fix: ClassEmbedder.forward takes the label tensor only; the original
        # passed (batch, key, disable_dropout) and failed with a TypeError.
        c_out = super().forward(batch[key])
        out[key] = [c_out] if islist else c_out
        return out


class FrozenT5Embedder(AbstractEmbModel):
    """Uses the T5 transformer encoder for text"""

    def __init__(
        self, version="google/t5-v1_1-xxl", device="cuda", max_length=77, freeze=True
    ):  # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
        super().__init__()
        self.tokenizer = T5Tokenizer.from_pretrained(version)
        self.transformer = T5EncoderModel.from_pretrained(version)
        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()

    def freeze(self):
        """Put the transformer in eval mode and disable gradients for all parameters."""
        self.transformer = self.transformer.eval()

        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        """Tokenize ``text`` (padded/truncated to ``max_length``) and return the last hidden state."""
        batch_encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_length=True,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )
        tokens = batch_encoding["input_ids"].to(self.device)
        # The encoder is run with CUDA autocast explicitly disabled.
        with torch.autocast("cuda", enabled=False):
            outputs = self.transformer(input_ids=tokens)
        z = outputs.last_hidden_state
        return z

    def encode(self, text):
        return self(text)


class FrozenByT5Embedder(AbstractEmbModel):
    """
    Uses the ByT5 transformer encoder for text. Is character-aware.
    """

    def __init__(
        self, version="google/byt5-base", device="cuda", max_length=77, freeze=True
    ):  # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
        super().__init__()
        self.tokenizer = ByT5Tokenizer.from_pretrained(version)
        self.transformer = T5EncoderModel.from_pretrained(version)
        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()

    def freeze(self):
        """Put the transformer in eval mode and disable gradients for all parameters."""
        self.transformer = self.transformer.eval()

        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        """Tokenize ``text`` (padded/truncated to ``max_length``) and return the last hidden state."""
        batch_encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_length=True,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )
        tokens = batch_encoding["input_ids"].to(self.device)
        # The encoder is run with CUDA autocast explicitly disabled.
        with torch.autocast("cuda", enabled=False):
            outputs = self.transformer(input_ids=tokens)
        z = outputs.last_hidden_state
        return z

    def encode(self, text):
        return self(text)


class FrozenCLIPEmbedder(AbstractEmbModel):
    """Uses the CLIP transformer encoder for text (from huggingface)"""

    LAYERS = ["last", "pooled", "hidden"]

    def __init__(
        self,
        version="openai/clip-vit-large-patch14",
        device="cuda",
        max_length=77,
        freeze=True,
        layer="last",
        layer_idx=None,
        always_return_pooled=False,
    ):  # clip-vit-base-patch32
        super().__init__()
        assert layer in self.LAYERS
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.transformer = CLIPTextModel.from_pretrained(version)
        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()
        self.layer = layer
        self.layer_idx = layer_idx
        self.return_pooled = always_return_pooled
        if layer == "hidden":
            assert layer_idx is not None
            assert 0 <= abs(layer_idx) <= 12

    def freeze(self):
        """Put the transformer in eval mode and disable gradients for all parameters."""
        self.transformer = self.transformer.eval()

        for param in self.parameters():
            param.requires_grad = False

    @autocast
    def forward(self, text):
        """Encode ``text`` and return the representation selected by ``self.layer``.

        Returns ``(z, pooler_output)`` when ``always_return_pooled`` was set,
        otherwise ``z`` alone.
        """
        batch_encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_length=True,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )
        tokens = batch_encoding["input_ids"].to(self.device)
        outputs = self.transformer(
            input_ids=tokens, output_hidden_states=self.layer == "hidden"
        )
        if self.layer == "last":
            z = outputs.last_hidden_state
        elif self.layer == "pooled":
            z = outputs.pooler_output[:, None, :]
        else:
            z = outputs.hidden_states[self.layer_idx]
        if self.return_pooled:
            return z, outputs.pooler_output
        return z

    def encode(self, text):
        return self(text)


class FrozenOpenCLIPEmbedder2(AbstractEmbModel):
    """
    Uses the OpenCLIP transformer encoder for text
    """

    LAYERS = ["pooled", "last", "penultimate"]

    def __init__(
        self,
        arch="ViT-H-14",
        version="laion2b_s32b_b79k",
        device="cuda",
        max_length=77,
        freeze=True,
        layer="last",
        always_return_pooled=False,
        legacy=True,
    ):
        super().__init__()
        assert layer in self.LAYERS
        model, _, _ = open_clip.create_model_and_transforms(
            arch,
            device=torch.device("cpu"),
            pretrained=version,
        )
        # Only the text tower is used; drop the vision tower.
        del model.visual
        self.model = model

        self.device = device
        self.max_length = max_length
        self.return_pooled = always_return_pooled
        if freeze:
            self.freeze()
        self.layer = layer
        # NOTE(review): "pooled" passes the LAYERS assertion above but ends up
        # in the NotImplementedError branch below.
        if self.layer == "last":
            self.layer_idx = 0
        elif self.layer == "penultimate":
            self.layer_idx = 1
        else:
            raise NotImplementedError()
        self.legacy = legacy

    def freeze(self):
        """Put the model in eval mode and disable gradients for all parameters."""
        self.model = self.model.eval()
        for param in self.parameters():
            param.requires_grad = False

    @autocast
    def forward(self, text):
        """Tokenize and encode ``text``.

        In legacy mode (without pooled output) the selected layer is returned
        directly. Otherwise the selected layer is returned, together with the
        pooled embedding when ``always_return_pooled`` was set.
        """
        tokens = open_clip.tokenize(text)
        z = self.encode_with_transformer(tokens.to(self.device))
        if not self.return_pooled and self.legacy:
            return z
        if self.return_pooled:
            assert not self.legacy
            return z[self.layer], z["pooled"]
        return z[self.layer]

    def encode_with_transformer(self, text):
        """Run token + positional embedding and the transformer stack on token ids."""
        x = self.model.token_embedding(text)  # [batch_size, n_ctx, d_model]
        x = x + self.model.positional_embedding
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
        if self.legacy:
            x = x[self.layer]
            x = self.model.ln_final(x)
            return x
        else:
            # x is a dict and will stay a dict
            o = x["last"]
            o = self.model.ln_final(o)
            pooled = self.pool(o, text)
            x["pooled"] = pooled
            return x

    def pool(self, x, text):
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = (
            x[torch.arange(x.shape[0]), text.argmax(dim=-1)]
            @ self.model.text_projection
        )
        return x

    def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
        """Run all residual blocks; return the ``"penultimate"`` and ``"last"`` activations."""
        outputs = {}
        for i, r in enumerate(self.model.transformer.resblocks):
            if i == len(self.model.transformer.resblocks) - 1:
                # Input to the final block, captured before it runs.
                outputs["penultimate"] = x.permute(1, 0, 2)  # LND -> NLD
            if (
                self.model.transformer.grad_checkpointing
                and not torch.jit.is_scripting()
            ):
                x = checkpoint(r, x, attn_mask)
            else:
                x = r(x, attn_mask=attn_mask)
        outputs["last"] = x.permute(1, 0, 2)  # LND -> NLD
        return outputs

    def encode(self, text):
        return self(text)


class FrozenOpenCLIPEmbedder(AbstractEmbModel):
    """Uses the OpenCLIP transformer encoder for text (last or penultimate layer)."""

    LAYERS = [
        # "pooled",
        "last",
        "penultimate",
    ]

    def __init__(
        self,
        arch="ViT-H-14",
        version="laion2b_s32b_b79k",
        device="cuda",
        max_length=77,
        freeze=True,
        layer="last",
    ):
        super().__init__()
        assert layer in self.LAYERS
        model, _, _ = open_clip.create_model_and_transforms(
            arch, device=torch.device("cpu"), pretrained=version
        )
        # Only the text tower is used; drop the vision tower.
        del model.visual
        self.model = model

        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()
        self.layer = layer
        # layer_idx is the number of trailing residual blocks to skip.
        if self.layer == "last":
            self.layer_idx = 0
        elif self.layer == "penultimate":
            self.layer_idx = 1
        else:
            raise NotImplementedError()

    def freeze(self):
        """Put the model in eval mode and disable gradients for all parameters."""
        self.model = self.model.eval()
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        """Tokenize ``text`` and return the transformer encoding."""
        tokens = open_clip.tokenize(text)
        z = self.encode_with_transformer(tokens.to(self.device))
        return z

    def encode_with_transformer(self, text):
        """Run token + positional embedding, the transformer stack and ``ln_final``."""
        x = self.model.token_embedding(text)  # [batch_size, n_ctx, d_model]
        x = x + self.model.positional_embedding
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.model.ln_final(x)
        return x

    def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
        """Run the residual blocks, stopping ``layer_idx`` blocks before the end."""
        for i, r in enumerate(self.model.transformer.resblocks):
            if i == len(self.model.transformer.resblocks) - self.layer_idx:
                break
            if (
                self.model.transformer.grad_checkpointing
                and not torch.jit.is_scripting()
            ):
                x = checkpoint(r, x, attn_mask)
            else:
                x = r(x, attn_mask=attn_mask)
        return x

    def encode(self, text):
        return self(text)


class FrozenOpenCLIPImageEmbedder(AbstractEmbModel):
    """
    Uses the OpenCLIP vision transformer encoder for images
    """

    def __init__(
        self,
        arch="ViT-H-14",
        version="laion2b_s32b_b79k",
        device="cuda",
        max_length=77,
        freeze=True,
        antialias=True,
        ucg_rate=0.0,
        unsqueeze_dim=False,
        repeat_to_max_len=False,
        num_image_crops=0,
        output_tokens=False,
        init_device=None,
    ):
        super().__init__()
        model, _, _ = open_clip.create_model_and_transforms(
            arch,
            device=torch.device(default(init_device, "cpu")),
            pretrained=version,
        )
        # Only the vision tower is used; drop the text transformer.
        del model.transformer
        self.model = model
        self.max_crops = num_image_crops
        self.pad_to_max_len = self.max_crops > 0
        # Padding (multi-crop mode) takes precedence over repeating.
        self.repeat_to_max_len = repeat_to_max_len and (not self.pad_to_max_len)
        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()

        self.antialias = antialias

        # Normalization statistics consumed by preprocess(); non-persistent
        # buffers, so they are not written to the state dict.
        self.register_buffer(
            "mean", torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False
        )
        self.register_buffer(
            "std", torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False
        )
        self.ucg_rate = ucg_rate
        self.unsqueeze_dim = unsqueeze_dim
        self.stored_batch = None
        self.model.visual.output_tokens = output_tokens
        self.output_tokens = output_tokens

    def preprocess(self, x):
        """Resize to 224x224 (bicubic), map ``x`` via ``(x + 1) / 2`` and normalize."""
        # normalize to [0,1]
        x = kornia.geometry.resize(
            x,
            (224, 224),
            interpolation="bicubic",
            align_corners=True,
            antialias=self.antialias,
        )
        x = (x + 1.0) / 2.0
        # renormalize according to clip
        x = kornia.enhance.normalize(x, self.mean, self.std)
        return x

    def freeze(self):
        """Put the model in eval mode and disable gradients for all parameters."""
        self.model = self.model.eval()
        for param in self.parameters():
            param.requires_grad = False

    @autocast
    def forward(self, image, no_dropout=False):
        """Embed ``image`` with the vision tower.

        Returns, depending on configuration:
            * ``(tokens, z)`` when ``output_tokens`` is set,
            * ``(repeated, z)`` when ``repeat_to_max_len`` is set,
            * ``(z_pad, z_pad[:, 0, ...])`` when multi-crop padding is active,
            * ``z`` otherwise.
        """
        z = self.encode_with_vision_transformer(image)
        tokens = None
        if self.output_tokens:
            z, tokens = z[0], z[1]
        z = z.to(image.dtype)
        # In multi-crop mode dropout is applied per crop inside
        # encode_with_vision_transformer instead of here.
        if self.ucg_rate > 0.0 and not no_dropout and not (self.max_crops > 0):
            z = (
                torch.bernoulli(
                    (1.0 - self.ucg_rate) * torch.ones(z.shape[0], device=z.device)
                )[:, None]
                * z
            )
            if tokens is not None:
                # NOTE(review): this mask is sampled independently of the one
                # applied to z above — confirm that is intended.
                tokens = (
                    expand_dims_like(
                        torch.bernoulli(
                            (1.0 - self.ucg_rate)
                            * torch.ones(tokens.shape[0], device=tokens.device)
                        ),
                        tokens,
                    )
                    * tokens
                )
        if self.unsqueeze_dim:
            z = z[:, None, :]
        if self.output_tokens:
            assert not self.repeat_to_max_len
            assert not self.pad_to_max_len
            return tokens, z
        if self.repeat_to_max_len:
            if z.dim() == 2:
                z_ = z[:, None, :]
            else:
                z_ = z
            return repeat(z_, "b 1 d -> b n d", n=self.max_length), z
        elif self.pad_to_max_len:
            assert z.dim() == 3
            # Zero-pad along dim 1 up to max_length entries.
            z_pad = torch.cat(
                (
                    z,
                    torch.zeros(
                        z.shape[0],
                        self.max_length - z.shape[1],
                        z.shape[2],
                        device=z.device,
                    ),
                ),
                1,
            )
            return z_pad, z_pad[:, 0, ...]
        return z

    def encode_with_vision_transformer(self, img):
        """Preprocess ``img`` and run the vision tower.

        A 5-D input is treated as ``(b, n, c, h, w)`` with ``n == max_crops``
        and flattened to ``(b n, c, h, w)`` first. Returns ``(x, tokens)`` when
        ``output_tokens`` is set, otherwise ``x``.
        """
        # if self.max_crops > 0:
        #    img = self.preprocess_by_cropping(img)
        if img.dim() == 5:
            assert self.max_crops == img.shape[1]
            img = rearrange(img, "b n c h w -> (b n) c h w")
        img = self.preprocess(img)
        if not self.output_tokens:
            assert not self.model.visual.output_tokens
            x = self.model.visual(img)
            tokens = None
        else:
            assert self.model.visual.output_tokens
            x, tokens = self.model.visual(img)
        if self.max_crops > 0:
            x = rearrange(x, "(b n) d -> b n d", n=self.max_crops)
            # drop out between 0 and all along the sequence axis
            x = (
                torch.bernoulli(
                    (1.0 - self.ucg_rate)
                    * torch.ones(x.shape[0], x.shape[1], 1, device=x.device)
                )
                * x
            )
            if tokens is not None:
                tokens = rearrange(tokens, "(b n) t d -> b t (n d)", n=self.max_crops)
                print(
                    f"You are running very experimental token-concat in {self.__class__.__name__}. "
                    f"Check what you are doing, and then remove this message."
                )
        if self.output_tokens:
            return x, tokens
        return x

    def encode(self, text):
        return self(text)


class FrozenCLIPT5Encoder(AbstractEmbModel):
    """Combines a ``FrozenCLIPEmbedder`` and a ``FrozenT5Embedder`` on the same text."""

    def __init__(
        self,
        clip_version="openai/clip-vit-large-patch14",
        t5_version="google/t5-v1_1-xl",
        device="cuda",
        clip_max_length=77,
        t5_max_length=77,
    ):
        super().__init__()
        self.clip_encoder = FrozenCLIPEmbedder(
            clip_version, device, max_length=clip_max_length
        )
        self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length)
        print(
            f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder) * 1.e-6:.2f} M parameters, "
            f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder) * 1.e-6:.2f} M params."
        )

    def encode(self, text):
        return self(text)

    def forward(self, text):
        """Return ``[clip_embedding, t5_embedding]`` for ``text``."""
        clip_z = self.clip_encoder.encode(text)
        t5_z = self.t5_encoder.encode(text)
        return [clip_z, t5_z]


class SpatialRescaler(nn.Module):
    """Rescales inputs by repeated interpolation, optionally remapping channels."""

    def __init__(
        self,
        n_stages=1,
        method="bilinear",
        multiplier=0.5,
        in_channels=3,
        out_channels=None,
        bias=False,
        wrap_video=False,
        kernel_size=1,
        remap_output=False,
    ):
        super().__init__()
        self.n_stages = n_stages
        assert self.n_stages >= 0
        assert method in [
            "nearest",
            "linear",
            "bilinear",
            "trilinear",
            "bicubic",
            "area",
        ]
        self.multiplier = multiplier
        self.interpolator = partial(torch.nn.functional.interpolate, mode=method)
        self.remap_output = out_channels is not None or remap_output
        if self.remap_output:
            print(
                f"Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing."
            )
            self.channel_mapper = nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                bias=bias,
                padding=kernel_size // 2,
            )
        self.wrap_video = wrap_video

    def forward(self, x):
        """Apply ``n_stages`` interpolation steps, then the optional channel mapper.

        With ``wrap_video`` set and a 5-D ``(b, c, t, h, w)`` input, the time
        axis is folded into the batch for interpolation and unfolded again
        afterwards.
        """
        # Fix: remember whether the video was actually folded. The original
        # unfolded whenever wrap_video was set, which failed on non-5-D inputs
        # because B/T/C had never been assigned.
        wrapped = self.wrap_video and x.ndim == 5
        if wrapped:
            B, C, T, _, _ = x.shape
            x = rearrange(x, "b c t h w -> b t c h w")
            x = rearrange(x, "b t c h w -> (b t) c h w")

        for _ in range(self.n_stages):
            x = self.interpolator(x, scale_factor=self.multiplier)

        if wrapped:
            x = rearrange(x, "(b t) c h w -> b t c h w", b=B, t=T, c=C)
            x = rearrange(x, "b t c h w -> b c t h w")
        if self.remap_output:
            x = self.channel_mapper(x)
        return x

    def encode(self, x):
        return self(x)


class LowScaleEncoder(nn.Module):
    """Encodes an input with a wrapped model and noises the result with ``q_sample``."""

    def __init__(
        self,
        model_config,
        linear_start,
        linear_end,
        timesteps=1000,
        max_noise_level=250,
        output_size=64,
        scale_factor=1.0,
    ):
        super().__init__()
        self.max_noise_level = max_noise_level
        self.model = instantiate_from_config(model_config)
        # register_schedule() has no return statement, so this attribute is
        # always None; the schedule itself lives in the registered buffers.
        self.augmentation_schedule = self.register_schedule(
            timesteps=timesteps, linear_start=linear_start, linear_end=linear_end
        )
        self.out_size = output_size
        self.scale_factor = scale_factor

    def register_schedule(
        self,
        beta_schedule="linear",
        timesteps=1000,
        linear_start=1e-4,
        linear_end=2e-2,
        cosine_s=8e-3,
    ):
        """Build the beta schedule and register the derived quantities as buffers."""
        betas = make_beta_schedule(
            beta_schedule,
            timesteps,
            linear_start=linear_start,
            linear_end=linear_end,
            cosine_s=cosine_s,
        )
        alphas = 1.0 - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])

        (timesteps,) = betas.shape
        self.num_timesteps = int(timesteps)
        self.linear_start = linear_start
        self.linear_end = linear_end
        assert (
            alphas_cumprod.shape[0] == self.num_timesteps
        ), "alphas have to be defined for each timestep"

        to_torch = partial(torch.tensor, dtype=torch.float32)

        self.register_buffer("betas", to_torch(betas))
        self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
        self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod)))
        self.register_buffer(
            "sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod))
        )
        self.register_buffer(
            "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod))
        )
        self.register_buffer(
            "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod))
        )
        self.register_buffer(
            "sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1))
        )

    def q_sample(self, x_start, t, noise=None):
        """Return ``sqrt(acp[t]) * x_start + sqrt(1 - acp[t]) * noise``.

        ``noise`` defaults to standard normal noise shaped like ``x_start``.
        """
        noise = default(noise, lambda: torch.randn_like(x_start))
        return (
            extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
            + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
            * noise
        )

    def forward(self, x):
        """Encode ``x``, scale it, noise it at a random level and optionally resize.

        Returns:
            Tuple ``(z, noise_level)``; ``noise_level`` holds one integer in
            ``[0, max_noise_level)`` per batch element.
        """
        z = self.model.encode(x)
        if isinstance(z, DiagonalGaussianDistribution):
            z = z.sample()
        z = z * self.scale_factor
        noise_level = torch.randint(
            0, self.max_noise_level, (x.shape[0],), device=x.device
        ).long()
        z = self.q_sample(z, noise_level)
        if self.out_size is not None:
            z = torch.nn.functional.interpolate(z, size=self.out_size, mode="nearest")
        return z, noise_level

    def decode(self, z):
        """Undo the scale factor and decode with the wrapped model."""
        z = z / self.scale_factor
        return self.model.decode(z)


class ConcatTimestepEmbedderND(AbstractEmbModel):
    """embeds each dimension independently and concatenates them"""

    def __init__(self, outdim):
        super().__init__()
        self.timestep = Timestep(outdim)
        self.outdim = outdim

    def forward(self, x):
        """Embed a 1-D or 2-D input; the result is reshaped to ``(b, dims * outdim)``."""
        if x.ndim == 1:
            x = x[:, None]
        assert len(x.shape) == 2
        b, dims = x.shape[0], x.shape[1]
        x = rearrange(x, "b d -> (b d)")
        emb = self.timestep(x)
        emb = rearrange(emb, "(b d) d2 -> b (d d2)", b=b, d=dims, d2=self.outdim)
        return emb


class GaussianEncoder(Encoder, AbstractEmbModel):
    """``Encoder`` whose output is passed through a ``DiagonalGaussianRegularizer``."""

    def __init__(
        self, weight: float = 1.0, flatten_output: bool = True, *args, **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.posterior = DiagonalGaussianRegularizer()
        self.weight = weight
        self.flatten_output = flatten_output

    def forward(self, x) -> Tuple[Dict, torch.Tensor]:
        """Encode ``x`` and regularize it.

        Returns:
            Tuple ``(log, z)``. ``log`` is the regularizer's log dict extended
            with ``"loss"`` (a copy of ``"kl_loss"``) and ``"weight"``. ``z`` is
            rearranged from ``b c h w`` to ``b (h w) c`` when ``flatten_output``
            is set.
        """
        z = super().forward(x)
        z, log = self.posterior(z)
        log["loss"] = log["kl_loss"]
        log["weight"] = self.weight
        if self.flatten_output:
            z = rearrange(z, "b c h w -> b (h w ) c")
        return log, z


class VideoPredictionEmbedderWithEncoder(AbstractEmbModel):
    """Encodes conditioning frames, optionally after adding sampled noise."""

    def __init__(
        self,
        n_cond_frames: int,
        n_copies: int,
        encoder_config: dict,
        sigma_sampler_config: Optional[dict] = None,
        sigma_cond_config: Optional[dict] = None,
        is_ae: bool = False,
        scale_factor: float = 1.0,
        disable_encoder_autocast: bool = False,
        en_and_decode_n_samples_a_time: Optional[int] = None,
    ):
        super().__init__()

        self.n_cond_frames = n_cond_frames
        self.n_copies = n_copies
        self.encoder = instantiate_from_config(encoder_config)
        self.sigma_sampler = (
            instantiate_from_config(sigma_sampler_config)
            if sigma_sampler_config is not None
            else None
        )
        self.sigma_cond = (
            instantiate_from_config(sigma_cond_config)
            if sigma_cond_config is not None
            else None
        )
        self.is_ae = is_ae
        self.scale_factor = scale_factor
        self.disable_encoder_autocast = disable_encoder_autocast
        self.en_and_decode_n_samples_a_time = en_and_decode_n_samples_a_time

    def forward(
        self, vid: torch.Tensor
    ) -> Union[
        torch.Tensor,
        Tuple[torch.Tensor, torch.Tensor],
        Tuple[torch.Tensor, dict],
        Tuple[Tuple[torch.Tensor, torch.Tensor], dict],
    ]:
        """Encode ``vid`` in chunks and repeat the result ``n_copies`` times.

        If a sigma sampler is configured, one sigma per group of
        ``n_cond_frames`` frames is sampled and Gaussian noise scaled by it is
        added before encoding.

        Returns:
            ``(vid, sigma_cond)`` when a sigma conditioner is configured,
            otherwise ``vid``.

        Raises:
            ValueError: if a sigma conditioner is configured without a sigma
                sampler, since there would be no sigmas to embed.
        """
        # Fix: the original reached the return statement with ``sigma_cond``
        # unbound in this configuration and failed only after the full
        # encoder pass.
        if self.sigma_cond is not None and self.sigma_sampler is None:
            raise ValueError(
                "sigma_cond_config requires sigma_sampler_config to be set as well"
            )

        sigma_cond = None
        if self.sigma_sampler is not None:
            b = vid.shape[0] // self.n_cond_frames
            sigmas = self.sigma_sampler(b).to(vid.device)
            if self.sigma_cond is not None:
                sigma_cond = self.sigma_cond(sigmas)
                sigma_cond = repeat(sigma_cond, "b d -> (b t) d", t=self.n_copies)
            sigmas = repeat(sigmas, "b -> (b t)", t=self.n_cond_frames)
            noise = torch.randn_like(vid)
            vid = vid + noise * append_dims(sigmas, vid.ndim)

        with torch.autocast("cuda", enabled=not self.disable_encoder_autocast):
            # Encode in chunks of n_samples along the batch axis.
            n_samples = (
                self.en_and_decode_n_samples_a_time
                if self.en_and_decode_n_samples_a_time is not None
                else vid.shape[0]
            )
            n_rounds = math.ceil(vid.shape[0] / n_samples)
            all_out = []
            for n in range(n_rounds):
                chunk = vid[n * n_samples : (n + 1) * n_samples]
                if self.is_ae:
                    out = self.encoder.encode(chunk)
                else:
                    out = self.encoder(chunk)
                all_out.append(out)

        vid = torch.cat(all_out, dim=0)
        vid *= self.scale_factor

        # Fold the conditioning frames into the channel axis, then replicate.
        vid = rearrange(vid, "(b t) c h w -> b () (t c) h w", t=self.n_cond_frames)
        vid = repeat(vid, "b 1 c h w -> (b t) c h w", t=self.n_copies)

        return_val = (vid, sigma_cond) if self.sigma_cond is not None else vid

        return return_val


class FrozenOpenCLIPImagePredictionEmbedder(AbstractEmbModel):
    """Wraps an image embedder and reshapes its output per conditioning frame."""

    def __init__(
        self,
        open_clip_embedding_config: Dict,
        n_cond_frames: int,
        n_copies: int,
    ):
        super().__init__()

        self.n_cond_frames = n_cond_frames
        self.n_copies = n_copies
        self.open_clip = instantiate_from_config(open_clip_embedding_config)

    def forward(self, vid):
        """Embed ``vid``, group by ``n_cond_frames`` and repeat ``n_copies`` times."""
        vid = self.open_clip(vid)
        vid = rearrange(vid, "(b t) d -> b t d", t=self.n_cond_frames)
        vid = repeat(vid, "b t d -> (b s) t d", s=self.n_copies)

        return vid