import torch
import torch.nn as nn
import kornia
import open_clip
from torchvision.transforms.functional import to_pil_image
from transformers import CLIPVisionModelWithProjection, AutoProcessor
from einops import rearrange, repeat

# FFN
# from mamba_ssm import Mamba


def expand_dims_like(x, y):
    # Unsqueeze trailing dims of x until its rank matches y's
    # (defined here so the module is self-contained; mirrors the usual sgm utility).
    while x.dim() != y.dim():
        x = x.unsqueeze(-1)
    return x


class ImgEmbContextResampler(nn.Module):
    def __init__(
        self,
        inner_dim=1280,
        cross_attention_dim=1024,
        expansion_factor=16,
        **kwargs,
    ):
        super().__init__()
        self.context_embedding = nn.Sequential(
            nn.Linear(cross_attention_dim, inner_dim),
            nn.SiLU(),
            nn.Linear(inner_dim, cross_attention_dim * expansion_factor),
        )
        self.expansion_factor = expansion_factor
        self.cross_attention_dim = cross_attention_dim

    def forward(self, x, batch_size=0):
        if x.ndim == 2:
            x = rearrange(x, "(B F) C -> B F C", B=batch_size)
        assert x.ndim == 3
        x = torch.mean(x, dim=1, keepdim=True)
        x = self.context_embedding(x)
        x = x.view(-1, self.expansion_factor, self.cross_attention_dim)
        return x


class AbstractEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding_dim = -1
        self.num_tokens = -1

    def encode(self, *args, **kwargs):
        raise NotImplementedError


class FrozenOpenCLIPImageEmbedder(AbstractEncoder):
    """Uses the OpenCLIP vision transformer encoder for images."""

    def __init__(
        self,
        arch="ViT-H-14",
        version="laion2b_s32b_b79k",
        device="cuda",
        max_length=77,
        freeze=True,
        antialias=True,
        ucg_rate=0.0,
        unsqueeze_dim=False,
        repeat_to_max_len=False,
        num_image_crops=0,
        output_tokens=False,
    ):
        super().__init__()
        model, _, _ = open_clip.create_model_and_transforms(
            arch,
            device=torch.device("cpu"),
            pretrained=version,
        )
        del model.transformer
        self.model = model
        # self.model_t = CLIPVisionModelWithProjection.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
        # self.processor = AutoProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
        self.max_crops = num_image_crops
        self.pad_to_max_len = self.max_crops > 0
        self.repeat_to_max_len = repeat_to_max_len and (not self.pad_to_max_len)
        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()

        self.antialias = antialias

        self.register_buffer(
            "mean", torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False
        )
        self.register_buffer(
            "std", torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False
        )

        self.ucg_rate = ucg_rate
        self.unsqueeze_dim = unsqueeze_dim
        self.stored_batch = None
        # self.model.visual.output_tokens = output_tokens
        self.output_tokens = output_tokens

    def preprocess(self, x):
        # resize to the CLIP input resolution
        x = kornia.geometry.resize(
            x,
            (224, 224),
            interpolation="bicubic",
            align_corners=True,
            antialias=self.antialias,
        )
        # map from [-1, 1] to [0, 1]
        x = (x + 1.0) / 2.0
        # renormalize according to CLIP statistics
        x = kornia.enhance.normalize(x, self.mean, self.std)
        return x

    def freeze(self):
        self.model = self.model.eval()
        for param in self.parameters():
            param.requires_grad = False
        # self.model_t = self.model_t.eval()

    def forward(self, image, no_dropout=False):
        z = self.encode_with_vision_transformer(image)
        tokens = None
        if self.output_tokens:
            z, tokens = z[0], z[1]
        z = z.to(image.dtype)
        if self.ucg_rate > 0.0 and not no_dropout and not (self.max_crops > 0):
            z = (
                torch.bernoulli(
                    (1.0 - self.ucg_rate) * torch.ones(z.shape[0], device=z.device)
                )[:, None]
                * z
            )
            if tokens is not None:
                tokens = (
                    expand_dims_like(
                        torch.bernoulli(
                            (1.0 - self.ucg_rate)
                            * torch.ones(tokens.shape[0], device=tokens.device)
                        ),
                        tokens,
                    )
                    * tokens
                )
        if self.unsqueeze_dim:
            z = z[:, None, :]
        if self.output_tokens:
            assert not self.repeat_to_max_len
            assert not self.pad_to_max_len
            return tokens, z
        if self.repeat_to_max_len:
            if z.dim() == 2:
                z_ = z[:, None, :]
            else:
                z_ = z
            return repeat(z_, "b 1 d -> b n d", n=self.max_length), z
        elif self.pad_to_max_len:
            assert z.dim() == 3
            z_pad = torch.cat(
                (
                    z,
                    torch.zeros(
                        z.shape[0],
                        self.max_length - z.shape[1],
                        z.shape[2],
                        device=z.device,
                    ),
                ),
                1,
            )
            return z_pad, z_pad[:, 0, ...]
        return z

    def encode_with_vision_transformer(self, img):
        if self.max_crops > 0:
            img = self.preprocess_by_cropping(img)
        # pil_img = to_pil_image(img[0] * 0.5 + 0.5)
        # inputs = self.processor(images=pil_img, return_tensors="pt").to("cuda")
        # outputs = self.model_t(**inputs)
        # return outputs.image_embeds
        if img.dim() == 5:
            assert self.max_crops == img.shape[1]
            img = rearrange(img, "b n c h w -> (b n) c h w")
        img = self.preprocess(img)
        if not self.output_tokens:
            assert not self.model.visual.output_tokens
            x = self.model.visual(img)
            tokens = None
        else:
            assert self.model.visual.output_tokens
            x, tokens = self.model.visual(img)
        if self.max_crops > 0:
            x = rearrange(x, "(b n) d -> b n d", n=self.max_crops)
            # drop out between 0 and all along the sequence axis
            x = (
                torch.bernoulli(
                    (1.0 - self.ucg_rate)
                    * torch.ones(x.shape[0], x.shape[1], 1, device=x.device)
                )
                * x
            )
            if tokens is not None:
                tokens = rearrange(tokens, "(b n) t d -> b t (n d)", n=self.max_crops)
                print(
                    f"You are running very experimental token-concat in {self.__class__.__name__}. "
                    f"Check what you are doing, and then remove this message."
                )
        if self.output_tokens:
            return x, tokens
        return x

    def encode(self, text):
        return self(text)
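

# ---------------------------------------------------------------------------
# Minimal usage sketch (assumption: not part of the original module; shapes
# are illustrative only). ImgEmbContextResampler expands a pooled image
# embedding of size `cross_attention_dim` into `expansion_factor` context
# tokens for cross-attention. FrozenOpenCLIPImageEmbedder would normally
# supply that embedding from images in [-1, 1]; it is omitted here because
# instantiating it downloads the ViT-H-14 weights.
if __name__ == "__main__":
    resampler = ImgEmbContextResampler(
        inner_dim=1280, cross_attention_dim=1024, expansion_factor=16
    )
    # Stand-in for pooled CLIP embeddings of 2 clips x 8 frames, flattened
    # to (B * F, C) as the forward pass expects for 2-D inputs.
    pooled = torch.randn(2 * 8, 1024)
    context = resampler(pooled, batch_size=2)
    print(context.shape)  # torch.Size([2, 16, 1024])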