import math
from typing import Any, Mapping

import torch
from torchvision.transforms.functional import to_pil_image
import torch.nn as nn
import kornia
import open_clip
from transformers import CLIPVisionModelWithProjection, AutoProcessor
from transformers.models.bit.image_processing_bit import BitImageProcessor
from einops import rearrange, repeat

# FFN
# from mamba_ssm import Mamba
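
# `expand_dims_like` is referenced further down but is neither imported nor
# defined in this file; the helper below is an assumed minimal re-implementation
# of the usual sgm-style utility: it unsqueezes trailing dimensions of `x` until
# it has the same rank as `y`, so the two tensors can broadcast.
def expand_dims_like(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    while x.dim() != y.dim():
        x = x.unsqueeze(-1)
    return x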


class ImgEmbContextResampler(nn.Module):
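    """Pools per-frame image embeddings into a single vector and expands it into
    ``expansion_factor`` context tokens of width ``cross_attention_dim`` for use
    as cross-attention conditioning."""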
    def __init__(
        self,
        inner_dim=1280,
        cross_attention_dim=1024,
        expansion_factor=16,
        **kwargs,
    ):
        super().__init__()
        self.context_embedding = nn.Sequential(
            nn.Linear(cross_attention_dim, inner_dim),
            nn.SiLU(),
            nn.Linear(inner_dim, cross_attention_dim * expansion_factor),
        )
        self.expansion_factor = expansion_factor
        self.cross_attention_dim = cross_attention_dim

    def forward(self, x, batch_size=0):
        if x.ndim == 2:
            # (B*F, C) -> (B, F, C)
            x = rearrange(x, "(B F) C -> B F C", B=batch_size)
        assert x.ndim == 3
        # average-pool over the frame axis, then expand into context tokens
        x = torch.mean(x, dim=1, keepdim=True)
        x = self.context_embedding(x)
        x = x.view(-1, self.expansion_factor, self.cross_attention_dim)
        return x


class AbstractEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding_dim = -1
        self.num_tokens = -1

    def encode(self, *args, **kwargs):
        raise NotImplementedError


class FrozenOpenCLIPImageEmbedder(AbstractEncoder):
    """
    Uses the OpenCLIP vision transformer encoder for images
    """

    def __init__(
        self,
        arch="ViT-H-14",
        version="laion2b_s32b_b79k",
        device="cuda",
        max_length=77,
        freeze=True,
        antialias=True,
        ucg_rate=0.0,
        unsqueeze_dim=False,
        repeat_to_max_len=False,
        num_image_crops=0,
        output_tokens=False,
    ):
        super().__init__()
        model, _, _ = open_clip.create_model_and_transforms(
            arch,
            device=torch.device("cpu"),
            pretrained=version,
        )
        # the text transformer is not needed for image embedding
        del model.transformer
        self.model = model
        # self.model_t = CLIPVisionModelWithProjection.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
        # self.processor = AutoProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
        self.max_crops = num_image_crops
        self.pad_to_max_len = self.max_crops > 0
        self.repeat_to_max_len = repeat_to_max_len and (not self.pad_to_max_len)
        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()
        self.antialias = antialias
        self.register_buffer(
            "mean", torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False
        )
        self.register_buffer(
            "std", torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False
        )
        self.ucg_rate = ucg_rate
        self.unsqueeze_dim = unsqueeze_dim
        self.stored_batch = None
        # propagate the flag to the visual tower so it returns patch tokens when requested
        self.model.visual.output_tokens = output_tokens
        self.output_tokens = output_tokens

    def preprocess(self, x):
        # expects x in [-1, 1]; resize to the CLIP input resolution
        x = kornia.geometry.resize(
            x,
            (224, 224),
            interpolation="bicubic",
            align_corners=True,
            antialias=self.antialias,
        )
        # normalize to [0, 1]
        x = (x + 1.0) / 2.0
        # renormalize according to the CLIP mean/std
        x = kornia.enhance.normalize(x, self.mean, self.std)
        return x

    def freeze(self):
        self.model = self.model.eval()
        for param in self.parameters():
            param.requires_grad = False
        # self.model_t = self.model_t.eval()

    def forward(self, image, no_dropout=False):
        z = self.encode_with_vision_transformer(image)
        tokens = None
        if self.output_tokens:
            z, tokens = z[0], z[1]
        z = z.to(image.dtype)
        if self.ucg_rate > 0.0 and not no_dropout and not (self.max_crops > 0):
            # randomly zero whole embeddings (classifier-free-guidance style dropout)
            z = (
                torch.bernoulli(
                    (1.0 - self.ucg_rate) * torch.ones(z.shape[0], device=z.device)
                )[:, None]
                * z
            )
            if tokens is not None:
                tokens = (
                    expand_dims_like(
                        torch.bernoulli(
                            (1.0 - self.ucg_rate)
                            * torch.ones(tokens.shape[0], device=tokens.device)
                        ),
                        tokens,
                    )
                    * tokens
                )
        if self.unsqueeze_dim:
            z = z[:, None, :]
        if self.output_tokens:
            assert not self.repeat_to_max_len
            assert not self.pad_to_max_len
            return tokens, z
        if self.repeat_to_max_len:
            if z.dim() == 2:
                z_ = z[:, None, :]
            else:
                z_ = z
            return repeat(z_, "b 1 d -> b n d", n=self.max_length), z
        elif self.pad_to_max_len:
            assert z.dim() == 3
            z_pad = torch.cat(
                (
                    z,
                    torch.zeros(
                        z.shape[0],
                        self.max_length - z.shape[1],
                        z.shape[2],
                        device=z.device,
                    ),
                ),
                1,
            )
            return z_pad, z_pad[:, 0, ...]
        return z

    def encode_with_vision_transformer(self, img):
        if self.max_crops > 0:
            # NOTE: `preprocess_by_cropping` is not defined in this file; this branch
            # is only reached when num_image_crops > 0 and relies on the method being
            # provided elsewhere (e.g. by a subclass).
            img = self.preprocess_by_cropping(img)
        # pil_img = to_pil_image(img[0]*0.5 + 0.5)
        # inputs = self.processor(images=pil_img, return_tensors="pt").to("cuda")
        # outputs = self.model_t(**inputs)
        # return outputs.image_embeds
        if img.dim() == 5:
            assert self.max_crops == img.shape[1]
            img = rearrange(img, "b n c h w -> (b n) c h w")
        img = self.preprocess(img)
        if not self.output_tokens:
            assert not self.model.visual.output_tokens
            x = self.model.visual(img)
            tokens = None
        else:
            assert self.model.visual.output_tokens
            x, tokens = self.model.visual(img)
        if self.max_crops > 0:
            x = rearrange(x, "(b n) d -> b n d", n=self.max_crops)
            # drop out between 0 and all along the sequence axis
            x = (
                torch.bernoulli(
                    (1.0 - self.ucg_rate)
                    * torch.ones(x.shape[0], x.shape[1], 1, device=x.device)
                )
                * x
            )
            if tokens is not None:
                tokens = rearrange(tokens, "(b n) t d -> b t (n d)", n=self.max_crops)
                print(
                    f"You are running very experimental token-concat in {self.__class__.__name__}. "
                    f"Check what you are doing, and then remove this message."
                )
        if self.output_tokens:
            return x, tokens
        return x

    def encode(self, image):
        # this encoder consumes images, not text
        return self(image)
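

# Minimal usage sketch: exercises only ImgEmbContextResampler with random
# tensors, since constructing FrozenOpenCLIPImageEmbedder downloads the
# laion2b ViT-H-14 weights. Batch/frame counts below are illustrative.
if __name__ == "__main__":
    resampler = ImgEmbContextResampler(
        inner_dim=1280, cross_attention_dim=1024, expansion_factor=16
    )
    # 2 videos x 8 frames, each frame with a 1024-dim CLIP image embedding
    frame_embeddings = torch.randn(2 * 8, 1024)
    context = resampler(frame_embeddings, batch_size=2)
    print(context.shape)  # torch.Size([2, 16, 1024])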