# Scraped file-view header (not part of the module):
# uploaded by hpoghos, commit c048c54 "add open_clip", file size 7.19 kB.
import math
from typing import Any, Mapping
import torch
from torchvision.transforms.functional import to_pil_image
import torch.nn as nn
import kornia
import open_clip
from transformers import CLIPVisionModelWithProjection, AutoProcessor
from transformers.models.bit.image_processing_bit import BitImageProcessor
from einops import rearrange, repeat
# FFN
# from mamba_ssm import Mamba
class ImgEmbContextResampler(nn.Module):
    """Pool a sequence of embeddings and expand the result into context tokens.

    The input is averaged over its second axis, passed through a two-layer
    MLP (``Linear -> SiLU -> Linear``) and reshaped into
    ``expansion_factor`` tokens of width ``cross_attention_dim``.

    Args:
        inner_dim: Hidden width of the MLP.
        cross_attention_dim: Channel width of the input embeddings and of
            every produced token.
        expansion_factor: Number of tokens produced per batch element.
        **kwargs: Accepted and ignored.
    """

    def __init__(
        self,
        inner_dim=1280,
        cross_attention_dim=1024,
        expansion_factor=16,
        **kwargs,
    ):
        super().__init__()
        self.context_embedding = nn.Sequential(
            nn.Linear(cross_attention_dim, inner_dim),
            nn.SiLU(),
            nn.Linear(inner_dim, cross_attention_dim * expansion_factor),
        )
        self.expansion_factor = expansion_factor
        self.cross_attention_dim = cross_attention_dim

    def forward(self, x, batch_size=0):
        """Return context tokens of shape ``(B, expansion_factor, cross_attention_dim)``.

        Args:
            x: Either a ``(B, F, C)`` tensor, or a flattened ``(B * F, C)``
                tensor, in which case ``batch_size`` must be given.
            batch_size: Value of ``B`` used to un-flatten a 2-D ``x``.
                Ignored when ``x`` is already 3-D.

        Raises:
            ValueError: If ``x`` is 2-D and ``batch_size`` is not a positive
                divisor of ``x.shape[0]``, or if ``x`` is neither 2-D nor 3-D.
        """
        if x.ndim == 2:
            # The default batch_size=0 cannot split the leading axis; fail
            # with a clear message instead of an opaque reshape error.
            if batch_size <= 0 or x.shape[0] % batch_size != 0:
                raise ValueError(
                    f"batch_size must be a positive divisor of x.shape[0]="
                    f"{x.shape[0]} for 2-D input, got {batch_size}"
                )
            # Same split as the pattern "(B F) C -> B F C" with B=batch_size.
            x = x.reshape(batch_size, -1, x.shape[-1])
        if x.ndim != 3:
            raise ValueError(f"expected a 2-D or 3-D tensor, got {x.ndim}-D")

        # Collapse the F axis to a single pooled embedding per batch element.
        x = torch.mean(x, dim=1, keepdim=True)
        x = self.context_embedding(x)
        # Split the widened channel axis into `expansion_factor` tokens.
        return x.view(-1, self.expansion_factor, self.cross_attention_dim)
class AbstractEncoder(nn.Module):
    """Base class for encoders; subclasses are expected to override ``encode``."""

    def __init__(self):
        """Initialise the module with placeholder size attributes."""
        super().__init__()
        # -1 marks both sizes as "not set"; this base class never changes them.
        self.embedding_dim = self.num_tokens = -1

    def encode(self, *inputs, **options):
        """Encode the given inputs. Not implemented in the base class.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError
class FrozenOpenCLIPImageEmbedder(AbstractEncoder):
    """
    Uses the OpenCLIP vision transformer encoder for images.

    The OpenCLIP model is created on CPU, its text transformer is deleted and
    only ``model.visual`` is used to embed images.

    Args:
        arch: Model name passed to ``open_clip.create_model_and_transforms``.
        version: Passed as ``pretrained`` to the same call.
        device: Stored on ``self.device``; the model itself is created on CPU.
        max_length: Sequence length used when repeating or zero-padding the
            embedding.
        freeze: If true, put the model in eval mode and disable gradients.
        antialias: Forwarded to ``kornia.geometry.resize``.
        ucg_rate: Probability of zeroing a sample's embedding in ``forward``.
        unsqueeze_dim: If true, insert a singleton axis at position 1 of the
            embedding.
        repeat_to_max_len: Repeat the embedding ``max_length`` times along
            axis 1. Ignored when ``num_image_crops > 0``.
        num_image_crops: Expected number of crops per sample; a value > 0
            enables the crop path and zero-padding up to ``max_length``.
        output_tokens: If true, the visual tower is asked to return tokens as
            well and ``forward`` returns ``(tokens, z)``.
    """

    def __init__(
        self,
        arch="ViT-H-14",
        version="laion2b_s32b_b79k",
        device="cuda",
        max_length=77,
        freeze=True,
        antialias=True,
        ucg_rate=0.0,
        unsqueeze_dim=False,
        repeat_to_max_len=False,
        num_image_crops=0,
        output_tokens=False,
    ):
        super().__init__()
        model, _, _ = open_clip.create_model_and_transforms(
            arch,
            device=torch.device("cpu"),
            pretrained=version,
        )
        # Only the visual tower is used by this class.
        del model.transformer
        self.model = model

        self.max_crops = num_image_crops
        self.pad_to_max_len = self.max_crops > 0
        self.repeat_to_max_len = repeat_to_max_len and (not self.pad_to_max_len)
        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()

        self.antialias = antialias

        # Per-channel statistics used by `preprocess`; not saved in state_dict.
        self.register_buffer(
            "mean", torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False
        )
        self.register_buffer(
            "std", torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False
        )
        self.ucg_rate = ucg_rate
        self.unsqueeze_dim = unsqueeze_dim
        self.stored_batch = None
        # Propagate the flag to the visual tower. Without this,
        # `encode_with_vision_transformer` asserts on a flag that was never
        # set, so output_tokens=True could not work.
        self.model.visual.output_tokens = output_tokens
        self.output_tokens = output_tokens

    def preprocess(self, x):
        """Resize to 224x224 and normalise with the registered mean/std.

        The ``(x + 1) / 2`` step maps the input to [0, 1] — presumably the
        caller supplies images in [-1, 1]; confirm against callers.
        """
        x = kornia.geometry.resize(
            x,
            (224, 224),
            interpolation="bicubic",
            align_corners=True,
            antialias=self.antialias,
        )
        # normalize to [0,1]
        x = (x + 1.0) / 2.0
        # renormalize according to clip
        x = kornia.enhance.normalize(x, self.mean, self.std)
        return x

    def freeze(self):
        """Switch the model to eval mode and disable gradients on all parameters."""
        self.model = self.model.eval()
        for param in self.parameters():
            param.requires_grad = False

    def _ucg_keep_mask(self, ref):
        """Sample a per-sample keep mask that broadcasts against ``ref``.

        Each entry is 1 with probability ``1 - ucg_rate`` and 0 otherwise.
        The mask has shape ``(ref.shape[0], 1, ..., 1)`` and ``ref``'s dtype,
        so multiplying by it does not promote half-precision embeddings.
        """
        keep = torch.bernoulli(
            (1.0 - self.ucg_rate) * torch.ones(ref.shape[0], device=ref.device)
        )
        return keep.view(-1, *([1] * (ref.dim() - 1))).to(ref.dtype)

    def forward(self, image, no_dropout=False):
        """Embed ``image`` and apply the configured dropout / reshaping.

        Args:
            image: Image batch handed to ``encode_with_vision_transformer``.
            no_dropout: If true, skip the per-sample dropout done here.

        Returns:
            * ``(tokens, z)`` when ``output_tokens`` is set;
            * ``(z repeated to max_length, z)`` when ``repeat_to_max_len``;
            * ``(z zero-padded to max_length, first entry of the padded z)``
              when ``num_image_crops > 0``;
            * ``z`` otherwise.
        """
        z = self.encode_with_vision_transformer(image)
        tokens = None
        if self.output_tokens:
            z, tokens = z[0], z[1]
        z = z.to(image.dtype)

        if self.ucg_rate > 0.0 and not no_dropout and not (self.max_crops > 0):
            # The masks for z and tokens are drawn independently, z first.
            z = self._ucg_keep_mask(z) * z
            if tokens is not None:
                tokens = self._ucg_keep_mask(tokens) * tokens

        if self.unsqueeze_dim:
            z = z[:, None, :]

        if self.output_tokens:
            assert not self.repeat_to_max_len
            assert not self.pad_to_max_len
            return tokens, z

        if self.repeat_to_max_len:
            z_ = z[:, None, :] if z.dim() == 2 else z
            return repeat(z_, "b 1 d -> b n d", n=self.max_length), z

        if self.pad_to_max_len:
            assert z.dim() == 3
            # Pad in z's own dtype so the concatenation keeps that dtype.
            padding = torch.zeros(
                z.shape[0],
                self.max_length - z.shape[1],
                z.shape[2],
                device=z.device,
                dtype=z.dtype,
            )
            z_pad = torch.cat((z, padding), 1)
            return z_pad, z_pad[:, 0, ...]

        return z

    def encode_with_vision_transformer(self, img):
        """Run the visual tower on ``img``.

        Returns the embedding ``x``, or ``(x, tokens)`` when ``output_tokens``
        is set. With ``num_image_crops > 0`` the crop axis is restored on
        ``x`` and per-crop dropout is applied.
        """
        if self.max_crops > 0:
            # NOTE(review): `preprocess_by_cropping` is not defined in the
            # visible part of this class — confirm it exists elsewhere.
            img = self.preprocess_by_cropping(img)
        if img.dim() == 5:
            assert self.max_crops == img.shape[1]
            img = rearrange(img, "b n c h w -> (b n) c h w")
        img = self.preprocess(img)

        if not self.output_tokens:
            assert not self.model.visual.output_tokens
            x = self.model.visual(img)
            tokens = None
        else:
            assert self.model.visual.output_tokens
            x, tokens = self.model.visual(img)

        if self.max_crops > 0:
            x = rearrange(x, "(b n) d -> b n d", n=self.max_crops)
            # drop out between 0 and all along the sequence axis
            # NOTE(review): this dropout does not honour forward's
            # `no_dropout` flag — confirm that is intended.
            x = (
                torch.bernoulli(
                    (1.0 - self.ucg_rate)
                    * torch.ones(x.shape[0], x.shape[1], 1, device=x.device)
                )
                * x
            )
            if tokens is not None:
                tokens = rearrange(tokens, "(b n) t d -> b t (n d)", n=self.max_crops)
                print(
                    f"You are running very experimental token-concat in {self.__class__.__name__}. "
                    f"Check what you are doing, and then remove this message."
                )

        if self.output_tokens:
            return x, tokens
        return x

    def encode(self, text):
        """Alias for ``forward``; despite its name, ``text`` is passed on as the image."""
        return self(text)