# Web-viewer header captured with this file (author: NIRVANALAN,
# commit 11e6f7b "init", listed size 57.6 kB).
import math
import random
import kiui
from kiui.op import recenter
import torchvision
import torchvision.transforms.v2
from contextlib import nullcontext
from functools import partial
from typing import Dict, List, Optional, Tuple, Union
from pdb import set_trace as st
import kornia
import numpy as np
import open_clip
import torch
import torch.nn as nn
from einops import rearrange, repeat
from omegaconf import ListConfig
from torch.utils.checkpoint import checkpoint
from transformers import (ByT5Tokenizer, CLIPTextModel, CLIPTokenizer,
T5EncoderModel, T5Tokenizer)
from ...modules.autoencoding.regularizers import DiagonalGaussianRegularizer
from ...modules.diffusionmodules.model import Encoder
from ...modules.diffusionmodules.openaimodel import Timestep
from ...modules.diffusionmodules.util import (extract_into_tensor,
make_beta_schedule)
from ...modules.distributions.distributions import DiagonalGaussianDistribution
from ...util import (append_dims, autocast, count_params, default,
disabled_train, expand_dims_like, instantiate_from_config)
from dit.dit_models_xformers import CaptionEmbedder, approx_gelu, t2i_modulate
class AbstractEmbModel(nn.Module):
    """Base class for the conditioning embedders driven by ``GeneralConditioner``.

    Exposes three read/write/deletable properties, each stored in an
    underscore-prefixed attribute and initialised to ``None``:

    * ``is_trainable`` -- whether the embedder is trained (``GeneralConditioner``
      freezes its parameters otherwise);
    * ``ucg_rate`` -- conditioning-dropout probability (float or tensor);
    * ``input_key`` -- batch key whose value is passed to the embedder.
    """

    def __init__(self):
        super().__init__()
        self._is_trainable = None
        self._ucg_rate = None
        self._input_key = None

    # -- is_trainable ---------------------------------------------------
    @property
    def is_trainable(self) -> bool:
        return self._is_trainable

    @is_trainable.setter
    def is_trainable(self, value: bool):
        self._is_trainable = value

    @is_trainable.deleter
    def is_trainable(self):
        del self._is_trainable

    # -- ucg_rate -------------------------------------------------------
    @property
    def ucg_rate(self) -> Union[float, torch.Tensor]:
        return self._ucg_rate

    @ucg_rate.setter
    def ucg_rate(self, value: Union[float, torch.Tensor]):
        self._ucg_rate = value

    @ucg_rate.deleter
    def ucg_rate(self):
        del self._ucg_rate

    # -- input_key ------------------------------------------------------
    @property
    def input_key(self) -> str:
        return self._input_key

    @input_key.setter
    def input_key(self, value: str):
        self._input_key = value

    @input_key.deleter
    def input_key(self):
        del self._input_key
class GeneralConditioner(nn.Module):
    """Runs a list of ``AbstractEmbModel`` embedders over a batch and merges
    their outputs.

    Every output tensor is routed to a key chosen by its number of dimensions
    (``OUTPUT_DIM2KEYS``). Tensors landing on the same key are concatenated
    along the dimension given by ``KEY2CATDIM``.
    """

    OUTPUT_DIM2KEYS = {2: "vector", 3: "crossattn", 4: "concat", 5: "concat"}
    KEY2CATDIM = {"vector": 1, "crossattn": 2, "concat": 1}

    def __init__(self, emb_models: Union[List, ListConfig]):
        """
        Args:
            emb_models: config entries, each built with ``instantiate_from_config``.
                Optional keys: ``is_trainable`` (default False), ``ucg_rate``
                (default 0.0), ``legacy_ucg_value`` (default None). One of
                ``input_key`` / ``input_keys`` is required (``input_key`` wins
                if both are present).

        Raises:
            KeyError: if an entry has neither ``input_key`` nor ``input_keys``.
        """
        super().__init__()
        embedders = []
        for n, embconfig in enumerate(emb_models):
            embedder = instantiate_from_config(embconfig)
            assert isinstance(
                embedder, AbstractEmbModel
            ), f"embedder model {embedder.__class__.__name__} has to inherit from AbstractEmbModel"
            embedder.is_trainable = embconfig.get("is_trainable", False)
            embedder.ucg_rate = embconfig.get("ucg_rate", 0.0)
            if not embedder.is_trainable:
                # Frozen embedder: replace train() with `disabled_train`
                # (presumably a no-op so eval mode sticks) and stop gradients.
                embedder.train = disabled_train
                for param in embedder.parameters():
                    param.requires_grad = False
                embedder.eval()
            print(
                f"Initialized embedder #{n}: {embedder.__class__.__name__} "
                f"with {count_params(embedder, False)} params. Trainable: {embedder.is_trainable}"
            )

            if "input_key" in embconfig:
                embedder.input_key = embconfig["input_key"]
            elif "input_keys" in embconfig:
                embedder.input_keys = embconfig["input_keys"]
            else:
                raise KeyError(
                    f"need either 'input_key' or 'input_keys' for embedder {embedder.__class__.__name__}"
                )

            embedder.legacy_ucg_val = embconfig.get("legacy_ucg_value", None)
            if embedder.legacy_ucg_val is not None:
                embedder.ucg_prng = np.random.RandomState()

            embedders.append(embedder)
        self.embedders = nn.ModuleList(embedders)

    def possibly_get_ucg_val(self, embedder: AbstractEmbModel,
                             batch: Dict) -> Dict:
        """Legacy UCG: replace each element of ``batch[embedder.input_key]``
        with ``embedder.legacy_ucg_val`` with probability ``ucg_rate``.

        Mutates ``batch`` in place and returns it.
        """
        assert embedder.legacy_ucg_val is not None
        p = embedder.ucg_rate
        val = embedder.legacy_ucg_val
        for i in range(len(batch[embedder.input_key])):
            if embedder.ucg_prng.choice(2, p=[1 - p, p]):
                batch[embedder.input_key][i] = val
        return batch

    def forward(self,
                batch: Dict,
                force_zero_embeddings: Optional[List] = None) -> Dict:
        """Embed ``batch`` with every embedder and merge outputs by rank.

        Args:
            batch: input dict; read via each embedder's ``input_key(s)``.
            force_zero_embeddings: input keys whose embeddings are replaced
                by zeros.

        Returns:
            Dict mapping "vector"/"crossattn"/"concat" to merged tensors.
        """
        output = dict()
        if force_zero_embeddings is None:
            force_zero_embeddings = []
        for embedder in self.embedders:
            # Frozen embedders run without building an autograd graph.
            embedding_context = nullcontext if embedder.is_trainable else torch.no_grad
            with embedding_context():
                if hasattr(embedder, "input_key") and (embedder.input_key
                                                       is not None):
                    if embedder.legacy_ucg_val is not None:
                        batch = self.possibly_get_ucg_val(embedder, batch)
                    emb_out = embedder(batch[embedder.input_key])
                elif hasattr(embedder, "input_keys"):
                    emb_out = embedder(
                        *[batch[k] for k in embedder.input_keys])
            assert isinstance(
                emb_out, (torch.Tensor, list, tuple)
            ), f"encoder outputs must be tensors or a sequence, but got {type(emb_out)}"
            if not isinstance(emb_out, (list, tuple)):
                emb_out = [emb_out]
            for emb in emb_out:
                out_key = self.OUTPUT_DIM2KEYS[emb.dim()]
                if embedder.ucg_rate > 0.0 and embedder.legacy_ucg_val is None:
                    # Per-sample conditioning dropout: keep with prob 1 - ucg_rate.
                    emb = (expand_dims_like(
                        torch.bernoulli(
                            (1.0 - embedder.ucg_rate) *
                            torch.ones(emb.shape[0], device=emb.device)),
                        emb,
                    ) * emb)
                if (hasattr(embedder, "input_key")
                        and embedder.input_key in force_zero_embeddings):
                    emb = torch.zeros_like(emb)
                if out_key in output:
                    output[out_key] = torch.cat((output[out_key], emb),
                                                self.KEY2CATDIM[out_key])
                else:
                    output[out_key] = emb
        return output

    def get_unconditional_conditioning(
        self,
        batch_c: Dict,
        batch_uc: Optional[Dict] = None,
        force_uc_zero_embeddings: Optional[List[str]] = None,
        force_cond_zero_embeddings: Optional[List[str]] = None,
    ):
        """Compute ``(c, uc)`` with conditioning dropout disabled.

        Every embedder's ``ucg_rate`` is set to 0 for the two forward passes
        and restored afterwards, even if a forward pass raises.

        Args:
            batch_c: batch for the conditional pass.
            batch_uc: batch for the unconditional pass (defaults to ``batch_c``).
            force_uc_zero_embeddings: input keys zeroed in the unconditional pass.
            force_cond_zero_embeddings: input keys zeroed in the conditional pass.
        """
        if force_uc_zero_embeddings is None:
            force_uc_zero_embeddings = []
        ucg_rates = [embedder.ucg_rate for embedder in self.embedders]
        for embedder in self.embedders:
            embedder.ucg_rate = 0.0  # ! force no drop during inference
        try:
            c = self(batch_c, force_cond_zero_embeddings)
            uc = self(batch_c if batch_uc is None else batch_uc,
                      force_uc_zero_embeddings)
        finally:
            for embedder, rate in zip(self.embedders, ucg_rates):
                embedder.ucg_rate = rate
        return c, uc
class InceptionV3(nn.Module):
    """Thin wrapper around the pytorch-fid InceptionV3 port
    (https://github.com/mseitzer/pytorch-fid).

    ``resize_input`` is always forced on. When the wrapped model returns a
    single output, that output is unwrapped and squeezed.
    """

    def __init__(self, normalize_input=False, **kwargs):
        super().__init__()
        # Local import: pytorch_fid is only required when this class is built.
        from pytorch_fid import inception

        kwargs["resize_input"] = True
        self.model = inception.InceptionV3(normalize_input=normalize_input,
                                           **kwargs)

    def forward(self, inp):
        features = self.model(inp)
        return features[0].squeeze() if len(features) == 1 else features
class IdentityEncoder(AbstractEmbModel):
    """Pass-through embedder: hands its input back untouched."""

    def forward(self, x):
        """Return ``x`` unchanged."""
        return x

    def encode(self, x):
        """Return ``x`` unchanged (same result as ``forward``)."""
        return x
class ClassEmbedder(AbstractEmbModel):
    """Learned embedding table for integer class labels.

    Args:
        embed_dim: size of each class embedding vector.
        n_classes: number of rows in the table. The last row
            (``n_classes - 1``) serves as the unconditional class.
        add_sequence_dim: if True, ``forward`` inserts a length-1 axis at
            position 1 of the output.
    """

    def __init__(self, embed_dim, n_classes=1000, add_sequence_dim=False):
        super().__init__()
        self.embedding = nn.Embedding(n_classes, embed_dim)
        self.n_classes = n_classes
        self.add_sequence_dim = add_sequence_dim

    def forward(self, c):
        c = self.embedding(c)
        if self.add_sequence_dim:
            c = c[:, None, :]
        return c

    def get_unconditional_conditioning(self, bs, device="cuda"):
        """Return ``{key: LongTensor[bs]}`` filled with the unconditional class.

        ``key`` is ``self.key`` if a subclass or caller has set it. Otherwise
        it is ``self.input_key``, because nothing in this class assigns
        ``self.key``.
        """
        # The last class index (n_classes - 1) is reserved for ucg,
        # e.g. n_classes=1000 -> class 999.
        uc_class = self.n_classes - 1
        uc = torch.full((bs, ), uc_class, dtype=torch.long, device=device)
        key = getattr(self, "key", None)
        if key is None:
            key = self.input_key
        return {key: uc}
class ClassEmbedderForMultiCond(ClassEmbedder):
    """``ClassEmbedder`` that takes a whole batch dict and replaces
    ``batch[key]`` with its embedding, in place.

    If ``batch[key]`` is a list, only its first element is embedded, and the
    result is stored back wrapped in a one-element list.
    """

    def forward(self, batch, key=None, disable_dropout=False):
        """
        Args:
            batch: dict holding class labels under ``key``; mutated in place.
            key: batch key to embed. Defaults to ``self.key`` if set,
                otherwise ``self.input_key``.
            disable_dropout: accepted for interface compatibility; this
                embedder applies no dropout of its own.

        Returns:
            The same ``batch`` dict, with ``batch[key]`` replaced by the embedding.
        """
        # Resolve the key lazily: `default(key, self.key)` would evaluate
        # `self.key` eagerly and raise because it is never assigned here.
        if key is None:
            key = getattr(self, "key", None)
        if key is None:
            key = self.input_key
        islist = isinstance(batch[key], list)
        labels = batch[key][0] if islist else batch[key]
        # ClassEmbedder.forward takes just the label tensor.
        c_out = super().forward(labels)
        batch[key] = [c_out] if islist else c_out
        return batch
class FrozenT5Embedder(AbstractEmbModel):
    """Text embedder backed by a HuggingFace T5 encoder.

    ``forward`` returns the encoder's ``last_hidden_state`` for prompts that
    are padded/truncated to ``max_length`` tokens. Alternative checkpoints:
    google/t5-v1_1-xl, google/t5-v1_1-xxl.
    """

    def __init__(self,
                 version="google/t5-v1_1-xxl",
                 device="cuda",
                 max_length=77,
                 freeze=True):
        super().__init__()
        self.tokenizer = T5Tokenizer.from_pretrained(version)
        self.transformer = T5EncoderModel.from_pretrained(version)
        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()

    def freeze(self):
        """Switch the encoder to eval mode and disable all parameter gradients."""
        self.transformer.eval()
        self.requires_grad_(False)

    def forward(self, text):
        """Tokenize ``text`` and return the encoder's last hidden state."""
        input_ids = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_length=True,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )["input_ids"].to(self.device)
        # Encoder runs with CUDA autocast explicitly disabled.
        with torch.autocast("cuda", enabled=False):
            hidden = self.transformer(input_ids=input_ids).last_hidden_state
        return hidden

    def encode(self, text):
        return self(text)
class FrozenByT5Embedder(AbstractEmbModel):
    """
    Uses the ByT5 transformer encoder for text. Is character-aware.

    ``forward`` returns the encoder's ``last_hidden_state`` for prompts that
    are byte-tokenized and padded/truncated to ``max_length``. Other ByT5
    sizes (e.g. google/byt5-small, google/byt5-large) may be passed as
    ``version``.
    """

    def __init__(self,
                 version="google/byt5-base",
                 device="cuda",
                 max_length=77,
                 freeze=True):
        super().__init__()
        self.tokenizer = ByT5Tokenizer.from_pretrained(version)
        self.transformer = T5EncoderModel.from_pretrained(version)
        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()

    def freeze(self):
        """Switch the encoder to eval mode and disable all parameter gradients."""
        self.transformer.eval()
        self.requires_grad_(False)

    def forward(self, text):
        """Tokenize ``text`` and return the encoder's last hidden state."""
        input_ids = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_length=True,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )["input_ids"].to(self.device)
        # Encoder runs with CUDA autocast explicitly disabled.
        with torch.autocast("cuda", enabled=False):
            hidden = self.transformer(input_ids=input_ids).last_hidden_state
        return hidden

    def encode(self, text):
        return self(text)
class FrozenCLIPEmbedder(AbstractEmbModel):
    """Uses the CLIP transformer encoder for text (from huggingface).

    ``layer`` picks the features that are returned:

    * ``"last"``: ``last_hidden_state``;
    * ``"pooled"``: ``pooler_output`` with a length-1 axis inserted at dim 1;
    * ``"hidden"``: ``hidden_states[layer_idx]`` (``layer_idx`` required).

    When ``always_return_pooled`` is set, ``pooler_output`` is returned as a
    second value.
    """

    LAYERS = ["last", "pooled", "hidden"]

    def __init__(
        self,
        version="openai/clip-vit-large-patch14",
        device="cuda",
        max_length=77,
        freeze=True,
        layer="last",
        layer_idx=None,
        always_return_pooled=False,
    ):  # clip-vit-base-patch32
        super().__init__()
        assert layer in self.LAYERS
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.transformer = CLIPTextModel.from_pretrained(version)
        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()
        self.layer = layer
        self.layer_idx = layer_idx
        self.return_pooled = always_return_pooled
        if layer == "hidden":
            assert layer_idx is not None
            assert 0 <= abs(layer_idx) <= 12

    def freeze(self):
        """Switch the transformer to eval mode and disable all parameter gradients."""
        self.transformer.eval()
        self.requires_grad_(False)

    @autocast
    def forward(self, text):
        input_ids = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_length=True,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )["input_ids"].to(self.device)
        wants_hidden = self.layer == "hidden"
        outputs = self.transformer(input_ids=input_ids,
                                   output_hidden_states=wants_hidden)
        if self.layer == "pooled":
            z = outputs.pooler_output[:, None, :]
        elif self.layer == "last":
            z = outputs.last_hidden_state
        else:
            z = outputs.hidden_states[self.layer_idx]
        return (z, outputs.pooler_output) if self.return_pooled else z

    def encode(self, text):
        return self(text)
class FrozenOpenCLIPEmbedder2(AbstractEmbModel):
    """
    Uses the OpenCLIP transformer encoder for text.

    The vision tower (``model.visual``) is deleted. ``layer`` selects the
    transformer output: ``"last"`` or ``"penultimate"``. ``"pooled"`` is listed
    in ``LAYERS`` but ``__init__`` raises ``NotImplementedError`` for it.

    * ``legacy=True``: ``forward`` returns the ``ln_final``-normalized output
      of the selected layer.
    * ``legacy=False``: ``forward`` returns the selected raw layer output (not
      passed through ``ln_final``). With ``always_return_pooled`` it also
      returns the pooled EOT projection.
    """

    LAYERS = ["pooled", "last", "penultimate"]

    def __init__(
        self,
        arch="ViT-H-14",
        version="laion2b_s32b_b79k",
        device="cuda",
        max_length=77,
        freeze=True,
        layer="last",
        always_return_pooled=False,
        legacy=True,
    ):
        super().__init__()
        assert layer in self.LAYERS
        model, _, _ = open_clip.create_model_and_transforms(
            arch,
            device=torch.device("cpu"),
            pretrained=version,
        )
        del model.visual
        self.model = model

        self.device = device
        self.max_length = max_length
        self.return_pooled = always_return_pooled
        if freeze:
            self.freeze()
        self.layer = layer
        if self.layer == "last":
            self.layer_idx = 0
        elif self.layer == "penultimate":
            self.layer_idx = 1
        else:
            raise NotImplementedError()
        self.legacy = legacy

    def freeze(self):
        self.model = self.model.eval()
        for param in self.parameters():
            param.requires_grad = False

    @autocast
    def forward(self, text):
        tokens = open_clip.tokenize(text)
        z = self.encode_with_transformer(tokens.to(self.device))
        if not self.return_pooled and self.legacy:
            return z
        if self.return_pooled:
            assert not self.legacy
            return z[self.layer], z["pooled"]
        return z[self.layer]

    def encode_with_transformer(self, text):
        """Embed token ids and run the text transformer.

        Returns a tensor in legacy mode. Otherwise returns the dict from
        ``text_transformer_forward`` with an added ``"pooled"`` entry.
        """
        x = self.model.token_embedding(text)  # [batch_size, n_ctx, d_model]
        x = x + self.model.positional_embedding
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
        if self.legacy:
            x = x[self.layer]
            x = self.model.ln_final(x)
            return x
        else:
            # x is a dict and will stay a dict
            o = x["last"]
            o = self.model.ln_final(o)
            pooled = self.pool(o, text)
            x["pooled"] = pooled
            return x

    def pool(self, x, text):
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = (x[torch.arange(x.shape[0]),
               text.argmax(dim=-1)] @ self.model.text_projection)
        return x

    def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
        """Run all residual blocks and record the penultimate and last outputs.

        Returns a dict with ``"penultimate"`` (input to the final block) and
        ``"last"``, both permuted LND -> NLD.
        """
        outputs = {}
        resblocks = self.model.transformer.resblocks
        for i, r in enumerate(resblocks):
            if i == len(resblocks) - 1:
                outputs["penultimate"] = x.permute(1, 0, 2)  # LND -> NLD
            if (self.model.transformer.grad_checkpointing
                    and not torch.jit.is_scripting()):
                # Bind attn_mask by keyword: in newer open_clip the block
                # signature is forward(q_x, k_x=None, v_x=None, attn_mask=None),
                # so passing it positionally would land in k_x.
                x = checkpoint(partial(r, attn_mask=attn_mask), x)
            else:
                x = r(x, attn_mask=attn_mask)
        outputs["last"] = x.permute(1, 0, 2)  # LND -> NLD
        return outputs

    def encode(self, text):
        return self(text)
class FrozenOpenCLIPEmbedder(AbstractEmbModel):
    """Uses the OpenCLIP transformer encoder for text.

    The vision tower is deleted. ``layer="last"`` runs every residual block;
    ``layer="penultimate"`` stops before the final block. In both cases the
    result goes through ``ln_final`` and is returned as NLD.
    """

    LAYERS = [
        # "pooled",
        "last",
        "penultimate",
    ]

    def __init__(
        self,
        arch="ViT-H-14",
        version="laion2b_s32b_b79k",
        device="cuda",
        max_length=77,
        freeze=True,
        layer="last",
    ):
        super().__init__()
        assert layer in self.LAYERS
        model, _, _ = open_clip.create_model_and_transforms(
            arch, device=torch.device("cpu"), pretrained=version)
        del model.visual
        self.model = model

        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()
        self.layer = layer
        if self.layer == "last":
            self.layer_idx = 0
        elif self.layer == "penultimate":
            self.layer_idx = 1
        else:
            raise NotImplementedError()

    def freeze(self):
        self.model = self.model.eval()
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        tokens = open_clip.tokenize(text)
        z = self.encode_with_transformer(tokens.to(self.device))
        return z

    def encode_with_transformer(self, text):
        x = self.model.token_embedding(text)  # [batch_size, n_ctx, d_model]
        x = x + self.model.positional_embedding
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.model.ln_final(x)
        return x

    def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
        """Run the residual blocks, skipping the last ``layer_idx`` of them."""
        resblocks = self.model.transformer.resblocks
        for i, r in enumerate(resblocks):
            if i == len(resblocks) - self.layer_idx:
                break
            if (self.model.transformer.grad_checkpointing
                    and not torch.jit.is_scripting()):
                # Bind attn_mask by keyword: in newer open_clip the block
                # signature is forward(q_x, k_x=None, v_x=None, attn_mask=None),
                # so passing it positionally would land in k_x.
                x = checkpoint(partial(r, attn_mask=attn_mask), x)
            else:
                x = r(x, attn_mask=attn_mask)
        return x

    def encode(self, text):
        return self(text)
class FrozenOpenCLIPImageEmbedder(AbstractEmbModel):
    """
    Uses the OpenCLIP vision transformer encoder for images.

    The text transformer (``model.transformer``) is deleted; only
    ``model.visual`` is used. ``forward`` returns, depending on configuration:

    * ``output_tokens=True``: ``(tokens, z)``;
    * ``repeat_to_max_len``: ``(z repeated max_length times along dim 1, z)``;
    * ``num_image_crops > 0`` (``pad_to_max_len``): ``(z zero-padded to
      max_length along dim 1, first entry of that padded sequence)``;
    * otherwise ``z`` alone.

    With ``ucg_rate > 0``, conditioning dropout zeroes whole samples (or, in
    the multi-crop path, individual crops) with probability ``ucg_rate``.
    """

    def __init__(
        self,
        # arch="ViT-H-14",
        # version="laion2b_s32b_b79k",
        arch="ViT-L-14",
        # version="laion2b_s32b_b82k",
        version="openai",
        device="cuda",
        max_length=77,
        freeze=True,
        antialias=True,
        ucg_rate=0.0,
        unsqueeze_dim=False,
        repeat_to_max_len=False,
        num_image_crops=0,
        output_tokens=False,
        init_device=None,
    ):
        super().__init__()
        model, _, _ = open_clip.create_model_and_transforms(
            arch,
            device=torch.device(default(init_device, "cpu")),
            pretrained=version,
        )
        # Drop the text transformer; only the visual tower is used below.
        del model.transformer
        self.model = model
        self.max_crops = num_image_crops
        # Multi-crop inputs are padded to max_length; repeating is only used
        # when padding is not.
        self.pad_to_max_len = self.max_crops > 0
        self.repeat_to_max_len = repeat_to_max_len and (
            not self.pad_to_max_len)
        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()

        self.antialias = antialias

        # CLIP normalization statistics; non-persistent, so not in state_dict.
        self.register_buffer("mean",
                             torch.Tensor([0.48145466, 0.4578275, 0.40821073]),
                             persistent=False)
        self.register_buffer("std",
                             torch.Tensor([0.26862954, 0.26130258,
                                           0.27577711]),
                             persistent=False)
        self.ucg_rate = ucg_rate
        self.unsqueeze_dim = unsqueeze_dim
        self.stored_batch = None
        self.model.visual.output_tokens = output_tokens
        self.output_tokens = output_tokens

    def preprocess(self, x):
        """Resize to 224x224 (bicubic) and apply CLIP mean/std normalization.

        Assumes ``x`` is in [-1, 1] (it is mapped to [0, 1] before normalizing)
        -- TODO confirm against callers.
        """
        # resize to 224x224
        x = kornia.geometry.resize(
            x,
            (224, 224),
            interpolation="bicubic",
            align_corners=True,
            antialias=self.antialias,
        )
        # map [-1, 1] -> [0, 1]
        x = (x + 1.0) / 2.0
        # renormalize according to clip
        x = kornia.enhance.normalize(x, self.mean, self.std)
        return x

    def freeze(self):
        self.model = self.model.eval()
        for param in self.parameters():
            param.requires_grad = False

    @autocast
    def forward(self, image, no_dropout=False):
        """Encode ``image``; see the class docstring for the return variants.

        Args:
            image: image batch; its dtype is used for the returned ``z``.
            no_dropout: if True, skip the per-sample ucg dropout below
                (the multi-crop dropout in ``encode_with_vision_transformer``
                is not affected).
        """
        z = self.encode_with_vision_transformer(image)
        tokens = None
        if self.output_tokens:
            z, tokens = z[0], z[1]
        z = z.to(image.dtype)
        # Per-sample dropout (multi-crop inputs are dropped per crop inside
        # encode_with_vision_transformer instead).
        if self.ucg_rate > 0.0 and not no_dropout and not (self.max_crops > 0):
            z = (torch.bernoulli(
                (1.0 - self.ucg_rate) *
                torch.ones(z.shape[0], device=z.device))[:, None] * z)
            if tokens is not None:
                tokens = (expand_dims_like(
                    torch.bernoulli(
                        (1.0 - self.ucg_rate) *
                        torch.ones(tokens.shape[0], device=tokens.device)),
                    tokens,
                ) * tokens)
        if self.unsqueeze_dim:
            z = z[:, None, :]
        if self.output_tokens:
            assert not self.repeat_to_max_len
            assert not self.pad_to_max_len
            return tokens, z
        if self.repeat_to_max_len:
            if z.dim() == 2:
                z_ = z[:, None, :]
            else:
                z_ = z
            return repeat(z_, "b 1 d -> b n d", n=self.max_length), z
        elif self.pad_to_max_len:
            assert z.dim() == 3
            # zero-pad the crop axis up to max_length
            z_pad = torch.cat(
                (
                    z,
                    torch.zeros(
                        z.shape[0],
                        self.max_length - z.shape[1],
                        z.shape[2],
                        device=z.device,
                    ),
                ),
                1,
            )
            return z_pad, z_pad[:, 0, ...]
        return z

    def encode_with_vision_transformer(self, img):
        """Run the OpenCLIP visual tower on ``img``.

        A 5-D input is treated as ``b n c h w`` with ``n == max_crops`` and
        flattened before encoding. Returns ``x`` or ``(x, tokens)`` when
        ``output_tokens`` is set.
        """
        # if self.max_crops > 0:
        #    img = self.preprocess_by_cropping(img)
        if img.dim() == 5:
            assert self.max_crops == img.shape[1]
            img = rearrange(img, "b n c h w -> (b n) c h w")
        img = self.preprocess(img)
        if not self.output_tokens:
            assert not self.model.visual.output_tokens
            x = self.model.visual(img)
            tokens = None
        else:
            assert self.model.visual.output_tokens
            x, tokens = self.model.visual(img)
        if self.max_crops > 0:
            x = rearrange(x, "(b n) d -> b n d", n=self.max_crops)
            # drop out between 0 and all along the sequence axis
            # NOTE(review): this dropout ignores forward()'s `no_dropout`.
            x = (torch.bernoulli(
                (1.0 - self.ucg_rate) *
                torch.ones(x.shape[0], x.shape[1], 1, device=x.device)) * x)
            if tokens is not None:
                tokens = rearrange(tokens,
                                   "(b n) t d -> b t (n d)",
                                   n=self.max_crops)
                print(
                    f"You are running very experimental token-concat in {self.__class__.__name__}. "
                    f"Check what you are doing, and then remove this message.")
        if self.output_tokens:
            return x, tokens
        return x

    def encode(self, text):
        # Despite the parameter name, this receives images (delegates to forward).
        return self(text)
# dino-v2 embedder
class FrozenDinov2ImageEmbedder(AbstractEmbModel):
    """
    Uses Dino-v2 (register variant, loaded via torch.hub) for low-level image
    embedding.

    ``forward`` returns the normalized patch tokens (``x_norm_patchtokens``).
    With ``output_cls=True`` it returns ``(patch_tokens, cls_token)``.
    Subclasses customize the backbone call by overriding ``_model_forward``;
    extra ``forward`` kwargs are passed through to it.
    """

    def __init__(
        self,
        arch="vitl",
        version="dinov2",  # by default
        device="cuda",
        max_length=77,
        freeze=True,
        antialias=True,
        ucg_rate=0.0,
        unsqueeze_dim=False,
        repeat_to_max_len=False,
        num_image_crops=0,
        output_tokens=False,
        output_cls=False,
        init_device=None,
    ):
        super().__init__()
        # Hub entry "<version>_<arch>14_reg", e.g. "dinov2_vitl14_reg".
        # The register variants are used for their better performance.
        self.model = torch.hub.load(
            f'facebookresearch/{version}',
            '{}_{}{}_reg'.format(version, f'{arch}', '14'),
            pretrained=True).to(torch.device(default(init_device, "cpu")))

        if freeze:
            self.freeze()

        self.max_crops = num_image_crops
        self.pad_to_max_len = self.max_crops > 0
        self.repeat_to_max_len = repeat_to_max_len and (
            not self.pad_to_max_len)
        self.device = device
        self.max_length = max_length

        self.antialias = antialias

        # https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/data/transforms.py#L41
        IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
        IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)

        self.register_buffer("mean",
                             torch.Tensor(IMAGENET_DEFAULT_MEAN),
                             persistent=False)
        self.register_buffer("std",
                             torch.Tensor(IMAGENET_DEFAULT_STD),
                             persistent=False)

        self.ucg_rate = ucg_rate
        self.unsqueeze_dim = unsqueeze_dim
        self.stored_batch = None
        # Stored for config compatibility; this class always returns patch tokens.
        self.output_tokens = output_tokens
        self.output_cls = output_cls

    def preprocess(self, x):
        """Resize to 224x224 (bicubic) and apply ImageNet mean/std normalization.

        Assumes ``x`` is in [-1, 1] (it is mapped to [0, 1] first) -- TODO
        confirm against callers.
        """
        x = kornia.geometry.resize(
            x,
            (224, 224),
            interpolation="bicubic",
            align_corners=True,
            antialias=self.antialias,
        )
        # map [-1, 1] -> [0, 1]
        x = (x + 1.0) / 2.0
        # renormalize with the ImageNet statistics DINOv2 expects
        x = kornia.enhance.normalize(x, self.mean, self.std)
        return x

    def freeze(self):
        self.model = self.model.eval()
        for param in self.parameters():
            param.requires_grad = False

    def _model_forward(self, *args, **kwargs):
        """Backbone call; overridden by subclasses to inject extra conditioning."""
        return self.model(*args, **kwargs)

    def encode_with_vision_transformer(self, img, **kwargs):
        """Preprocess ``img`` and run the backbone.

        A 5-D input is flattened over its first two axes. ``kwargs`` are
        forwarded to ``_model_forward`` in both output modes.

        Returns:
            patch tokens, or ``(cls_token, patch_tokens)`` when ``output_cls``.
        """
        if img.dim() == 5:
            img = rearrange(img, "b n c h w -> (b n) c h w")
        img = self.preprocess(img)

        # https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L326
        dino_ret_dict = self._model_forward(img, is_training=True, **kwargs)
        if not self.output_cls:
            return dino_ret_dict['x_norm_patchtokens']  # spatial tokens
        return dino_ret_dict['x_norm_clstoken'], dino_ret_dict[
            'x_norm_patchtokens']

    @autocast
    def forward(self, image, no_dropout=False, **kwargs):
        """Encode ``image`` and apply optional per-sample ucg dropout.

        Returns patch tokens, or ``(patch_tokens, cls_token)`` when
        ``output_cls``. Outputs are cast to ``image.dtype``.
        """
        encoded = self.encode_with_vision_transformer(image, **kwargs)
        z = None
        if self.output_cls:
            # encode_with_vision_transformer returned (cls_token, patch_tokens)
            z, tokens = encoded[0], encoded[1]
            z = z.to(image.dtype)
        else:
            tokens = encoded
        tokens = tokens.to(image.dtype)  # ! return spatial tokens only
        if self.ucg_rate > 0.0 and not no_dropout and not (self.max_crops > 0):
            if z is not None:
                z = (torch.bernoulli(
                    (1.0 - self.ucg_rate) *
                    torch.ones(z.shape[0], device=z.device))[:, None] * z)
            tokens = (expand_dims_like(
                torch.bernoulli(
                    (1.0 - self.ucg_rate) *
                    torch.ones(tokens.shape[0], device=tokens.device)),
                tokens,
            ) * tokens)
        if self.output_cls:
            return tokens, z
        return tokens
class FrozenDinov2ImageEmbedderMVPlucker(FrozenDinov2ImageEmbedder):
    """Multi-view DINOv2 embedder conditioned on per-pixel Plucker rays.

    The backbone's patch embedding is replaced by a 9-channel one: image
    channels followed by a 6-channel Plucker ray map (``o x d`` then ``d``).
    Only the first 3 input channels get the pretrained weights; the rest are
    zero-initialized.

    ``forward`` expects ``{'img': ..., 'c': ...}``. Each camera is assumed to
    be 25 values: a flattened 4x4 c2w followed by a flattened 3x3 intrinsic
    matrix with normalized focal/principal point -- TODO confirm with the
    data loader.
    """

    def __init__(
        self,
        arch="vitl",
        version="dinov2",  # by default
        device="cuda",
        max_length=77,
        freeze=True,
        antialias=True,
        ucg_rate=0.0,
        unsqueeze_dim=False,
        repeat_to_max_len=False,
        num_image_crops=0,
        output_tokens=False,
        output_cls=False,
        init_device=None,
        # mv cond settings
        n_cond_frames=4,  # number of condition views
        enable_bf16=False,
        modLN=False,  # unused by this class
        aug_c=False,
    ):
        super().__init__(
            arch,
            version,
            device,
            max_length,
            freeze,
            antialias,
            ucg_rate,
            unsqueeze_dim,
            repeat_to_max_len,
            num_image_crops,
            output_tokens,
            output_cls,
            init_device,
        )
        self.n_cond_frames = n_cond_frames
        self.dtype = torch.bfloat16 if enable_bf16 else torch.float32
        self.enable_bf16 = enable_bf16
        self.aug_c = aug_c

        # Resolution at which Plucker ray maps are generated.
        self.reso_encoder = 224
        orig_patch_embed_weight = self.model.patch_embed.state_dict()
        orig_device = self.model.patch_embed.proj.weight.device

        # ! 9-d input
        with torch.no_grad():
            # Build on the backbone's device so patch_embed is not left on CPU
            # when init_device is a GPU.
            new_patch_embed = PatchEmbed(img_size=224,
                                         patch_size=14,
                                         in_chans=9,
                                         embed_dim=self.model.embed_dim).to(orig_device)
            # zero init first
            nn.init.constant_(new_patch_embed.proj.weight, 0)
            nn.init.constant_(new_patch_embed.proj.bias, 0)
            # load pre-trained first 3 layers weights, bias into the new patch_embed
            new_patch_embed.proj.weight[:, :3].copy_(
                orig_patch_embed_weight['proj.weight'])
            new_patch_embed.proj.bias[:].copy_(
                orig_patch_embed_weight['proj.bias'])

        self.model.patch_embed = new_patch_embed  # image channels in the front

    @autocast
    def scale_jitter_aug(self, x):
        """Downsample by a random factor in [0.5, 1) and upsample back (bilinear)."""
        inp_size = x.shape[2]
        aug_size = int(max(0.5, random.random()) * inp_size)
        x = torch.nn.functional.interpolate(x,
                                            size=aug_size,
                                            mode='bilinear',
                                            antialias=True)
        x = torch.nn.functional.interpolate(x,
                                            size=inp_size,
                                            mode='bilinear',
                                            antialias=True)
        return x

    @autocast
    def gen_rays(self, c):
        """Return per-pixel ray origins and unit directions, each (h, w, 3).

        Pixel centers are normalized to [0, 1] before applying the intrinsics,
        and directions follow the OpenCV convention (+z forward). Side effect:
        sets ``self.h`` / ``self.w`` to ``reso_encoder``.
        """
        intrinsics, c2w = c[16:], c[:16].reshape(4, 4)
        self.h = self.reso_encoder
        self.w = self.reso_encoder
        yy, xx = torch.meshgrid(
            torch.arange(self.h, dtype=torch.float32, device=c.device) + 0.5,
            torch.arange(self.w, dtype=torch.float32, device=c.device) + 0.5,
            indexing='ij')

        # normalize to 0-1 pixel range
        yy = yy / self.h
        xx = xx / self.w

        # K = [f_x, 0, cx, 0, f_y, cy, 0, 0, 1] flattened row-major
        cx, cy, fx, fy = intrinsics[2], intrinsics[5], intrinsics[
            0], intrinsics[4]

        c2w = c2w.float()

        xx = (xx - cx) / fx
        yy = (yy - cy) / fy
        zz = torch.ones_like(xx)
        dirs = torch.stack((xx, yy, zz), dim=-1)  # OpenCV convention
        dirs /= torch.norm(dirs, dim=-1, keepdim=True)
        dirs = dirs.reshape(-1, 3, 1)
        del xx, yy, zz
        dirs = (c2w[None, :3, :3] @ dirs)[..., 0]

        origins = c2w[None, :3, 3].expand(self.h * self.w, -1).contiguous()
        origins = origins.view(self.h, self.w, 3)
        dirs = dirs.view(self.h, self.w, 3)

        return origins, dirs

    @autocast
    def get_plucker_ray(self, c):
        """Stack per-camera Plucker maps ``[o x d, d]`` into (N, 6, h, w)."""
        rays_plucker = []
        for idx in range(c.shape[0]):
            rays_o, rays_d = self.gen_rays(c[idx])
            rays_plucker.append(
                torch.cat([torch.cross(rays_o, rays_d, dim=-1), rays_d],
                          dim=-1).permute(2, 0, 1))  # [h, w, 6] -> 6,h,w
        rays_plucker = torch.stack(rays_plucker, 0)
        return rays_plucker

    @autocast
    def _model_forward(self, x, plucker_c, *args, **kwargs):
        """Concatenate image and Plucker channels, then run the backbone."""
        with torch.cuda.amp.autocast(dtype=self.dtype, enabled=True):
            x = torch.cat([x, plucker_c], dim=1).to(self.dtype)
            return self.model(x, **kwargs)

    def preprocess(self, x):
        """Apply the parent's resize + normalization.

        NOTE(review): when ``ucg_rate > 0`` an augmented ``x_aug`` (scale
        jitter + Gaussian-noise blend) is computed, but it is *not* returned.
        The output is the unaugmented ``x``, and the augmentation only
        consumes RNG state. Confirm whether it should be enabled.
        """
        if self.ucg_rate > 0.0:
            # 1 means maintain the input x
            enable_drop_flag = torch.bernoulli(
                (1.0 - self.ucg_rate) *
                torch.ones(x.shape[0], device=x.device))[:, None, None,
                                                          None]  # broadcast to B,1,1,1

            # * add random downsample & upsample
            x_aug = self.scale_jitter_aug(x)

            # * gaussian noise jitter, blend weight clamped to >= 0.5
            gaussian_jitter_scale, jitter_lb = torch.rand_like(
                enable_drop_flag), 0.5 * torch.ones_like(enable_drop_flag)
            gaussian_jitter_scale = torch.where(
                gaussian_jitter_scale > jitter_lb, gaussian_jitter_scale,
                jitter_lb)

            x_aug = gaussian_jitter_scale * x_aug + (
                1 - gaussian_jitter_scale) * torch.randn_like(x).clamp(-1, 1)

            x_aug = x * enable_drop_flag + (1 - enable_drop_flag) * x_aug

        return super().preprocess(x)

    def random_rotate_c(self, c):
        """Left-multiply the c2w of one camera by a random rotation about z or y."""
        intrinsics, c2ws = c[16:], c[:16].reshape(4, 4)

        # https://github.com/TencentARC/InstantMesh/blob/34c193cc96eebd46deb7c48a76613753ad777122/src/data/objaverse.py#L195
        degree = np.random.uniform(0, math.pi * 2)
        # random rotation along z axis
        if random.random() > 0.5:
            rot = torch.tensor([
                [np.cos(degree), -np.sin(degree), 0, 0],
                [np.sin(degree), np.cos(degree), 0, 0],
                [0, 0, 1, 0],
                [0, 0, 0, 1],
            ]).to(c2ws)
        else:
            # random rotation along y axis
            rot = torch.tensor([
                [np.cos(degree), 0, np.sin(degree), 0],
                [0, 1, 0, 0],
                [-np.sin(degree), 0, np.cos(degree), 0],
                [0, 0, 0, 1],
            ]).to(c2ws)

        c2ws = torch.matmul(rot, c2ws)

        return torch.cat([c2ws.reshape(-1), intrinsics])

    @autocast
    def forward(self, img_c, no_dropout=False):
        """Encode the first ``n_cond_frames`` views of ``img_c``.

        Args:
            img_c: dict with ``'img'`` (assumed ``b t c h w``) and ``'c'``
                (assumed ``b t 25``). It is not modified.
            no_dropout: forwarded to the parent to skip ucg dropout.

        Returns:
            Patch tokens reshaped to ``b t ...``.
        """
        mv_image, c = img_c['img'], img_c['c']

        if self.aug_c:
            # Augment a copy so the caller's camera tensor is left untouched.
            c = c.clone()
            for idx_b in range(c.shape[0]):
                for idx_v in range(c.shape[1]):
                    if random.random() > 0.6:
                        c[idx_b, idx_v] = self.random_rotate_c(c[idx_b, idx_v])

        plucker_c = self.get_plucker_ray(
            rearrange(c[:, :self.n_cond_frames], "b t ... -> (b t) ..."))

        mv_image_tokens = super().forward(mv_image[:, :self.n_cond_frames],
                                          plucker_c=plucker_c,
                                          no_dropout=no_dropout)

        mv_image_tokens = rearrange(mv_image_tokens,
                                    "(b t) ... -> b t ...",
                                    t=self.n_cond_frames)

        return mv_image_tokens
def make_2tuple(x):
    """Normalize ``x`` to a pair.

    An ``int`` is duplicated into ``(x, x)``. A 2-tuple is returned as-is.
    Anything else fails an assertion.
    """
    if not isinstance(x, tuple):
        assert isinstance(x, int)
        return (x, x)
    assert len(x) == 2
    return x
class PatchEmbed(nn.Module):
    """
    2D image to patch embedding: (B,C,H,W) -> (B,N,D)

    Args:
        img_size: Image size.
        patch_size: Patch token size.
        in_chans: Number of input image channels.
        embed_dim: Number of linear projection output channels.
        norm_layer: Normalization layer constructor (called with
            ``embed_dim``); ``nn.Identity`` is used when omitted.
        flatten_embedding: if False, ``forward`` returns (B, H', W', D)
            instead of (B, N, D).
    """

    def __init__(
        self,
        img_size: Union[int, Tuple[int, int]] = 224,
        patch_size: Union[int, Tuple[int, int]] = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        norm_layer=None,
        flatten_embedding: bool = True,
    ) -> None:
        super().__init__()

        image_HW = make_2tuple(img_size)
        patch_HW = make_2tuple(patch_size)
        patch_grid_size = (
            image_HW[0] // patch_HW[0],
            image_HW[1] // patch_HW[1],
        )

        self.img_size = image_HW
        self.patch_size = patch_HW
        self.patches_resolution = patch_grid_size
        self.num_patches = patch_grid_size[0] * patch_grid_size[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.flatten_embedding = flatten_embedding

        # Non-overlapping patches: kernel size == stride == patch size.
        self.proj = nn.Conv2d(in_chans,
                              embed_dim,
                              kernel_size=patch_HW,
                              stride=patch_HW)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        _, _, H, W = x.shape
        patch_H, patch_W = self.patch_size

        assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
        assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"

        x = self.proj(x)  # B C H W
        H, W = x.size(2), x.size(3)
        x = x.flatten(2).transpose(1, 2)  # B HW C
        x = self.norm(x)
        if not self.flatten_embedding:
            x = x.reshape(-1, H, W, self.embed_dim)  # B H W C
        return x

    def flops(self) -> float:
        """Approximate FLOPs of the patch projection, plus the norm if one is used."""
        Ho, Wo = self.patches_resolution
        flops = Ho * Wo * self.embed_dim * self.in_chans * (
            self.patch_size[0] * self.patch_size[1])
        # `self.norm` is never None (it is nn.Identity without a norm_layer),
        # so only count norm cost for a real normalization layer.
        if not isinstance(self.norm, nn.Identity):
            flops += Ho * Wo * self.embed_dim
        return flops
class FrozenDinov2ImageEmbedderMV(FrozenDinov2ImageEmbedder):
def __init__(
self,
arch="vitl",
version="dinov2", # by default
device="cuda",
max_length=77,
freeze=True,
antialias=True,
ucg_rate=0.0,
unsqueeze_dim=False,
repeat_to_max_len=False,
num_image_crops=0,
output_tokens=False,
output_cls=False,
init_device=None,
# mv cond settings
n_cond_frames=4, # numebr of condition views
enable_bf16=False,
modLN=False,
):
super().__init__(
arch,
version,
device,
max_length,
freeze,
antialias,
ucg_rate,
unsqueeze_dim,
repeat_to_max_len,
num_image_crops,
output_tokens,
output_cls,
init_device,
)
self.n_cond_frames = n_cond_frames
self.dtype = torch.bfloat16 if enable_bf16 else torch.float32
self.enable_bf16 = enable_bf16
# ! proj c_cond to features
hidden_size = self.model.embed_dim # 768 for vit-b
# self.cam_proj = CaptionEmbedder(16, hidden_size,
self.cam_proj = CaptionEmbedder(25, hidden_size, act_layer=approx_gelu)
# ! single-modLN
self.model.modLN_modulation = nn.Sequential(
nn.SiLU(), nn.Linear(hidden_size, 4 * hidden_size, bias=True))
# zero-init modLN
nn.init.constant_(self.model.modLN_modulation[-1].weight, 0)
nn.init.constant_(self.model.modLN_modulation[-1].bias, 0)
# inject modLN to dino block
for block in self.model.blocks:
block.scale_shift_table = nn.Parameter(torch.zeros(
4, hidden_size)) # zero init also
# torch.randn(4, hidden_size) / hidden_size**0.5)
def _model_forward(self, x, *args, **kwargs):
    """Run the DINOv2 trunk with per-block condition modulation (modLN).

    Replaces the wrapped model's feature pass. Tokens are prepared by
    ``self.model.prepare_tokens_with_masks``. Each block's attention and MLP
    residual branches are then recomputed by hand, so that shift/scale terms
    derived from the condition ``c`` can be injected.

    Only supported in training mode (asserted below).

    Args:
        x: image batch accepted by ``self.model.prepare_tokens_with_masks``
            (presumably (B, 3, H, W) — TODO confirm against caller).
        *args: ignored.
        **kwargs: must provide ``c``, the per-image condition vector fed to
            ``self.cam_proj`` (constructed in ``__init__`` with input width 25).

    Returns:
        dict with ``"x_norm_clstoken"`` (normalized token 0) and
        ``"x_norm_patchtokens"`` (normalized tokens that follow the
        ``num_register_tokens`` register tokens).
    """
    # re-define model forward, finetune dino-v2.
    assert self.training
    # ? how to send in camera
    # c = 0 # placeholder
    # ret = self.model.forward_features(*args, **kwargs)
    # self.dtype is bfloat16 when enable_bf16 was set, otherwise float32.
    with torch.cuda.amp.autocast(dtype=self.dtype, enabled=True):
        x = self.model.prepare_tokens_with_masks(x, masks=None)
        B, N, C = x.shape  # only B is used below
        # TODO how to send in c
        # c = torch.ones(B, 25).to(x) # placeholder
        # NOTE(review): c is None if the caller omits it; cam_proj presumably
        # fails on None — confirm callers always pass c.
        c = kwargs.get('c')
        c = self.cam_proj(c)
        # One condition vector per sample, shared by all blocks and later
        # split into 4 hidden_size-wide chunks.
        cond = self.model.modLN_modulation(c)
        # https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/layers/block.py#L89
        for blk in self.model.blocks:  # inject modLN
            # Per-block learned table (zero-initialised in __init__, shape
            # (4, hidden)) plus the shared condition reshaped to (B, 4, hidden).
            # Each of the four chunks is (B, 1, hidden).
            shift_msa, scale_msa, shift_mlp, scale_mlp = (
                blk.scale_shift_table[None] +
                cond.reshape(B, 4, -1)).chunk(4, dim=1)

            # These closures capture this iteration's blk / shift / scale and
            # are called immediately below, so late binding is harmless.
            def attn_residual_func(x: torch.Tensor) -> torch.Tensor:
                # return blk.ls1(blk.attn(blk.norm1(x), attn_bias=attn_bias))
                return blk.ls1(
                    blk.attn(
                        t2i_modulate(blk.norm1(x), shift_msa, scale_msa)))

            def ffn_residual_func(x: torch.Tensor) -> torch.Tensor:
                # return blk.ls2(blk.mlp(blk.norm2(x)))
                # NOTE(review): the attention branch modulates norm1's output
                # *before* attn, but this branch modulates the MLP's *output*.
                # It may be intentional. Changing it would alter the behaviour
                # of trained checkpoints, so confirm before "fixing" it.
                return blk.ls2(
                    t2i_modulate(blk.mlp(blk.norm2(x)), shift_mlp,
                                 scale_mlp))

            x = x + blk.drop_path1(
                attn_residual_func(x))  # all drop_path identity() here.
            x = x + blk.drop_path2(ffn_residual_func(x))
        x_norm = self.model.norm(x)
    # Token 0 is the CLS token; the register tokens sit between CLS and the
    # patch tokens and are skipped.
    return {
        "x_norm_clstoken": x_norm[:, 0],
        # "x_norm_regtokens": x_norm[:, 1 : self.model.num_register_tokens + 1],
        "x_norm_patchtokens": x_norm[:,
                                     self.model.num_register_tokens + 1:],
        # "x_prenorm": x,
        # "masks": masks,
    }
@autocast
def forward(self, img_c, no_dropout=False):
    """Embed the conditioning views of a multi-view batch.

    Frame 0 is treated as the canonical view and is not embedded. The next
    ``self.n_cond_frames`` views (and their matching entries of ``c``) are
    flattened into the batch axis and sent through the parent class's
    ``forward``.

    Args:
        img_c: mapping with ``'img'`` (views on axis 1) and ``'c'``
            (per-view condition, views on axis 1).
        no_dropout: forwarded unchanged to the parent ``forward``.

    Returns:
        The parent's output regrouped as ``(b, t, ...)`` with
        ``t == self.n_cond_frames``.
    """
    images = img_c['img']
    conditions = img_c['c']
    # An earlier ablation replaced the condition with zeros here
    # (author note: "current version wrong").
    cond_views = slice(1, 1 + self.n_cond_frames)
    selected_images = images[:, cond_views]
    flat_conditions = rearrange(conditions[:, cond_views],
                                "b t ... -> (b t) ...",
                                t=self.n_cond_frames)
    embedded = super().forward(selected_images,
                               c=flat_conditions,
                               no_dropout=no_dropout)
    return rearrange(embedded, "(b t) ... -> b t ...", t=self.n_cond_frames)
class FrozenCLIPT5Encoder(AbstractEmbModel):
    """Text conditioner that runs a CLIP text embedder and a T5 embedder.

    Both sub-encoders receive the same text. The result is the pair of their
    ``encode`` outputs, CLIP first.
    """

    def __init__(
        self,
        clip_version="openai/clip-vit-large-patch14",
        t5_version="google/t5-v1_1-xl",
        device="cuda",
        clip_max_length=77,
        t5_max_length=77,
    ):
        super().__init__()
        self.clip_encoder = FrozenCLIPEmbedder(
            clip_version, device, max_length=clip_max_length)
        self.t5_encoder = FrozenT5Embedder(
            t5_version, device, max_length=t5_max_length)

        clip_millions = count_params(self.clip_encoder) * 1.e-6
        t5_millions = count_params(self.t5_encoder) * 1.e-6
        print(
            f"{type(self.clip_encoder).__name__} has {clip_millions:.2f} M parameters, "
            f"{type(self.t5_encoder).__name__} comes with {t5_millions:.2f} M params."
        )

    def encode(self, text):
        """Alias for calling the module."""
        return self(text)

    def forward(self, text):
        """Return ``[clip_embedding, t5_embedding]`` for ``text``."""
        encoders = (self.clip_encoder, self.t5_encoder)
        return [encoder.encode(text) for encoder in encoders]
class SpatialRescaler(nn.Module):
    """Resize feature maps by repeated interpolation, optionally remapping channels.

    The input is interpolated ``n_stages`` times by ``multiplier`` using
    ``torch.nn.functional.interpolate`` in the chosen ``method``. A
    convolution can then change the channel count.

    Args:
        n_stages: number of successive interpolation steps (``>= 0``).
        method: interpolation mode passed to ``interpolate``.
        multiplier: ``scale_factor`` applied at every stage.
        in_channels: input channel count (used only by the remap conv).
        out_channels: channel count after remapping. Setting it enables
            remapping. When ``remap_output=True`` and this is ``None``, the
            channel count is kept at ``in_channels``.
        bias: whether the remap convolution has a bias term.
        wrap_video: also accept 5-D ``(b, c, t, h, w)`` input by folding the
            time axis into the batch. 4-D input is still handled normally.
        kernel_size: remap conv kernel size (padding ``kernel_size // 2``
            keeps the spatial size for odd kernels).
        remap_output: enable remapping even without ``out_channels``.
    """

    def __init__(
        self,
        n_stages=1,
        method="bilinear",
        multiplier=0.5,
        in_channels=3,
        out_channels=None,
        bias=False,
        wrap_video=False,
        kernel_size=1,
        remap_output=False,
    ):
        super().__init__()
        self.n_stages = n_stages
        assert self.n_stages >= 0
        assert method in [
            "nearest",
            "linear",
            "bilinear",
            "trilinear",
            "bicubic",
            "area",
        ]
        self.multiplier = multiplier
        self.interpolator = partial(torch.nn.functional.interpolate,
                                    mode=method)
        self.remap_output = out_channels is not None or remap_output
        if self.remap_output:
            if out_channels is None:
                # remap_output=True without an explicit width: keep channels
                # (Conv2d cannot be built with out_channels=None).
                out_channels = in_channels
            print(
                f"Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing."
            )
            self.channel_mapper = nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                bias=bias,
                padding=kernel_size // 2,
            )
        self.wrap_video = wrap_video

    def forward(self, x):
        """Rescale ``x``: 4-D ``(b, c, h, w)``, or 5-D video if ``wrap_video``."""
        # Fold time into batch only for genuine 5-D input, so 4-D tensors
        # still work when wrap_video is set.
        is_video = self.wrap_video and x.ndim == 5
        if is_video:
            B, _, T, _, _ = x.shape
            x = rearrange(x, "b c t h w -> b t c h w")
            x = rearrange(x, "b t c h w -> (b t) c h w")
        for _ in range(self.n_stages):
            x = self.interpolator(x, scale_factor=self.multiplier)
        if self.remap_output:
            # Conv2d needs a 4-D tensor, so remap before restoring the time
            # axis. Channels may change here, hence no c=... below.
            x = self.channel_mapper(x)
        if is_video:
            x = rearrange(x, "(b t) c h w -> b t c h w", b=B, t=T)
            x = rearrange(x, "b t c h w -> b c t h w")
        return x

    def encode(self, x):
        """Alias for calling the module."""
        return self(x)
class LowScaleEncoder(nn.Module):
    """Low-resolution conditioner: encode, scale, noise to a random level, resize.

    ``model_config`` is instantiated into a first-stage model whose
    ``encode``/``decode`` methods are used. In ``forward`` the scaled latent
    is noised by the diffusion forward process ``q_sample``. The noise level
    is drawn uniformly from ``[0, max_noise_level)`` per sample.
    """

    def __init__(
        self,
        model_config,
        linear_start,
        linear_end,
        timesteps=1000,
        max_noise_level=250,
        output_size=64,
        scale_factor=1.0,
    ):
        super().__init__()
        self.max_noise_level = max_noise_level
        self.model = instantiate_from_config(model_config)
        # register_schedule only registers buffers and returns None, so this
        # attribute always ends up None. It is kept for compatibility.
        self.augmentation_schedule = self.register_schedule(
            timesteps=timesteps,
            linear_start=linear_start,
            linear_end=linear_end,
        )
        self.out_size = output_size
        self.scale_factor = scale_factor

    def register_schedule(
        self,
        beta_schedule="linear",
        timesteps=1000,
        linear_start=1e-4,
        linear_end=2e-2,
        cosine_s=8e-3,
    ):
        """Build the noise schedule and register its tables as float32 buffers."""
        betas = make_beta_schedule(beta_schedule,
                                   timesteps,
                                   linear_start=linear_start,
                                   linear_end=linear_end,
                                   cosine_s=cosine_s)
        cum = np.cumprod(1.0 - betas, axis=0)
        cum_prev = np.append(1.0, cum[:-1])

        (n_steps, ) = betas.shape  # also enforces a 1-D schedule
        self.num_timesteps = int(n_steps)
        self.linear_start = linear_start
        self.linear_end = linear_end
        assert (cum.shape[0] == self.num_timesteps
                ), "alphas have to be defined for each timestep"

        # Insertion order fixes the buffer registration order.
        tables = {
            "betas": betas,
            "alphas_cumprod": cum,
            "alphas_cumprod_prev": cum_prev,
            # terms used by q(x_t | x_{t-1}) and related formulas
            "sqrt_alphas_cumprod": np.sqrt(cum),
            "sqrt_one_minus_alphas_cumprod": np.sqrt(1.0 - cum),
            "log_one_minus_alphas_cumprod": np.log(1.0 - cum),
            "sqrt_recip_alphas_cumprod": np.sqrt(1.0 / cum),
            "sqrt_recipm1_alphas_cumprod": np.sqrt(1.0 / cum - 1),
        }
        for buffer_name, values in tables.items():
            self.register_buffer(buffer_name,
                                 torch.tensor(values, dtype=torch.float32))

    def q_sample(self, x_start, t, noise=None):
        """Noise ``x_start`` to step ``t``: sqrt(acp_t)*x + sqrt(1-acp_t)*noise."""
        noise = default(noise, lambda: torch.randn_like(x_start))
        signal_coef = extract_into_tensor(self.sqrt_alphas_cumprod, t,
                                          x_start.shape)
        noise_coef = extract_into_tensor(self.sqrt_one_minus_alphas_cumprod,
                                         t, x_start.shape)
        return signal_coef * x_start + noise_coef * noise

    def forward(self, x):
        """Return ``(noised_latent, noise_level)`` for the input batch ``x``."""
        latent = self.model.encode(x)
        # A posterior-returning first-stage model: draw a sample from it.
        if isinstance(latent, DiagonalGaussianDistribution):
            latent = latent.sample()
        latent = latent * self.scale_factor
        batch_size = x.shape[0]
        noise_level = torch.randint(0,
                                    self.max_noise_level, (batch_size, ),
                                    device=x.device).long()
        latent = self.q_sample(latent, noise_level)
        if self.out_size is None:
            return latent, noise_level
        resized = torch.nn.functional.interpolate(latent,
                                                  size=self.out_size,
                                                  mode="nearest")
        return resized, noise_level

    def decode(self, z):
        """Undo the latent scaling and decode with the first-stage model."""
        return self.model.decode(z / self.scale_factor)
class ConcatTimestepEmbedderND(AbstractEmbModel):
    """Embed each column of a ``(batch, dims)`` tensor and concatenate the results.

    Every scalar is passed independently through a ``Timestep(outdim)``
    module. The per-column embeddings are then laid side by side, giving an
    output of shape ``(batch, dims * outdim)``. A 1-D input is treated as a
    single column.
    """

    def __init__(self, outdim):
        super().__init__()
        self.timestep = Timestep(outdim)
        self.outdim = outdim

    def forward(self, x):
        if x.ndim == 1:
            x = x[:, None]
        assert len(x.shape) == 2
        batch, n_cols = x.shape
        flat_embeddings = self.timestep(rearrange(x, "b d -> (b d)"))
        return rearrange(flat_embeddings,
                         "(b d) d2 -> b (d d2)",
                         b=batch,
                         d=n_cols,
                         d2=self.outdim)
class GaussianEncoder(Encoder, AbstractEmbModel):
    """``Encoder`` followed by a ``DiagonalGaussianRegularizer``.

    ``forward`` returns ``(log, z)``. ``log`` is the regularizer's log dict
    with ``"loss"`` (a copy of its ``"kl_loss"``) and ``"weight"`` added.
    ``z`` is the regularized latent, optionally flattened to
    ``(b, h*w, c)`` tokens.
    """

    def __init__(self,
                 weight: float = 1.0,
                 flatten_output: bool = True,
                 *args,
                 **kwargs):
        super().__init__(*args, **kwargs)
        self.posterior = DiagonalGaussianRegularizer()
        self.weight = weight
        self.flatten_output = flatten_output

    def forward(self, x) -> Tuple[Dict, torch.Tensor]:
        latent, log = self.posterior(super().forward(x))
        log.update(loss=log["kl_loss"], weight=self.weight)
        if not self.flatten_output:
            return log, latent
        return log, rearrange(latent, "b c h w -> b (h w ) c")
class VideoPredictionEmbedderWithEncoder(AbstractEmbModel):
    """Encode conditioning frames, stack them on channels and tile the result.

    The input batch holds ``n_cond_frames`` consecutive frames per sample.
    If a sigma sampler is configured, noise scaled by a per-sample sigma is
    added first. Frames are encoded (in chunks of
    ``en_and_decode_n_samples_a_time`` if set) and scaled by
    ``scale_factor``. Each sample's frames are then concatenated along the
    channel axis, and the result is repeated ``n_copies`` times.

    Args:
        n_cond_frames: frames per sample in the incoming batch.
        n_copies: how many times each encoded sample is repeated.
        encoder_config: config for the frame encoder.
        sigma_sampler_config: optional config for the noise-level sampler.
        sigma_cond_config: optional config embedding the sampled sigmas.
            Requires ``sigma_sampler_config``.
        is_ae: call ``encoder.encode`` instead of calling ``encoder`` itself.
        scale_factor: multiplier applied to the encoded latents.
        disable_encoder_autocast: turn off CUDA autocast around encoding.
        en_and_decode_n_samples_a_time: max chunk size for encoding
            (``None`` encodes everything at once).
    """

    def __init__(
        self,
        n_cond_frames: int,
        n_copies: int,
        encoder_config: dict,
        sigma_sampler_config: Optional[dict] = None,
        sigma_cond_config: Optional[dict] = None,
        is_ae: bool = False,
        scale_factor: float = 1.0,
        disable_encoder_autocast: bool = False,
        en_and_decode_n_samples_a_time: Optional[int] = None,
    ):
        super().__init__()
        self.n_cond_frames = n_cond_frames
        self.n_copies = n_copies
        self.encoder = instantiate_from_config(encoder_config)
        self.sigma_sampler = (instantiate_from_config(sigma_sampler_config)
                              if sigma_sampler_config is not None else None)
        self.sigma_cond = (instantiate_from_config(sigma_cond_config)
                           if sigma_cond_config is not None else None)
        self.is_ae = is_ae
        self.scale_factor = scale_factor
        self.disable_encoder_autocast = disable_encoder_autocast
        self.en_and_decode_n_samples_a_time = en_and_decode_n_samples_a_time

    def forward(
        self, vid: torch.Tensor
    ) -> Union[
            torch.Tensor,
            Tuple[torch.Tensor, torch.Tensor],
            Tuple[torch.Tensor, dict],
            Tuple[Tuple[torch.Tensor, torch.Tensor], dict],
    ]:
        """Encode ``vid`` (``(b * n_cond_frames, c, h, w)``).

        Returns:
            The tiled latents of shape ``(b * n_copies, n_cond_frames * c', h', w')``,
            paired with the repeated sigma embedding when ``sigma_cond`` is
            configured.

        Raises:
            ValueError: if ``sigma_cond`` is configured without a sigma sampler.
                Previously this surfaced as an ``UnboundLocalError`` at return.
        """
        if self.sigma_cond is not None and self.sigma_sampler is None:
            raise ValueError(
                "sigma_cond_config requires sigma_sampler_config: the sigma "
                "embedding is computed from the sampled noise levels.")

        sigma_cond = None
        if self.sigma_sampler is not None:
            b = vid.shape[0] // self.n_cond_frames
            sigmas = self.sigma_sampler(b).to(vid.device)
            if self.sigma_cond is not None:
                sigma_cond = self.sigma_cond(sigmas)
                sigma_cond = repeat(sigma_cond,
                                    "b d -> (b t) d",
                                    t=self.n_copies)
            # One sigma per sample, shared by all of that sample's frames.
            sigmas = repeat(sigmas, "b -> (b t)", t=self.n_cond_frames)
            noise = torch.randn_like(vid)
            vid = vid + noise * append_dims(sigmas, vid.ndim)

        chunk = (self.en_and_decode_n_samples_a_time
                 if self.en_and_decode_n_samples_a_time is not None else
                 vid.shape[0])
        n_rounds = math.ceil(vid.shape[0] / chunk)
        encode = self.encoder.encode if self.is_ae else self.encoder
        with torch.autocast("cuda",
                            enabled=not self.disable_encoder_autocast):
            all_out = [
                encode(vid[i * chunk:(i + 1) * chunk])
                for i in range(n_rounds)
            ]

        vid = torch.cat(all_out, dim=0)
        vid *= self.scale_factor
        # Stack each sample's frames along channels, then tile n_copies times.
        vid = rearrange(vid,
                        "(b t) c h w -> b () (t c) h w",
                        t=self.n_cond_frames)
        vid = repeat(vid, "b 1 c h w -> (b t) c h w", t=self.n_copies)

        return (vid, sigma_cond) if self.sigma_cond is not None else vid
class FrozenOpenCLIPImagePredictionEmbedder(AbstractEmbModel):
    """Embed ``n_cond_frames`` frames per sample and tile them ``n_copies`` times.

    The embedder is built from ``open_clip_embedding_config`` and must return
    one vector per frame. The output is ``(b * n_copies, n_cond_frames, d)``.
    """

    def __init__(
        self,
        open_clip_embedding_config: Dict,
        n_cond_frames: int,
        n_copies: int,
    ):
        super().__init__()
        self.n_cond_frames = n_cond_frames
        self.n_copies = n_copies
        self.open_clip = instantiate_from_config(open_clip_embedding_config)

    def forward(self, vid):
        frame_embeddings = self.open_clip(vid)
        per_sample = rearrange(frame_embeddings,
                               "(b t) d -> b t d",
                               t=self.n_cond_frames)
        return repeat(per_sample, "b t d -> (b s) t d", s=self.n_copies)
class FrozenOpenCLIPImageMVEmbedder(AbstractEmbModel):
    """Multi-view 3D diffusion condition that embeds only the first view.

    The embedder is built from ``open_clip_embedding_config``. ``forward``
    takes the view at index 0 of axis 1 and passes it, with
    ``no_dropout``, to that embedder.
    """

    def __init__(
        self,
        open_clip_embedding_config: Dict,
    ):
        super().__init__()
        self.open_clip = instantiate_from_config(open_clip_embedding_config)

    def forward(self, vid, no_dropout=False):
        first_view = vid[:, 0]
        return self.open_clip(first_view, no_dropout=no_dropout)