import os
from typing import Callable, Optional

import numpy as np
import torch
from transformers import CLIPModel, CLIPProcessor
from transformers.modeling_utils import PreTrainedModel

from .configuration_cased import CaSEDConfig
from .retrieval_cased import RetrievalDatabase, download_retrieval_databases
from .transforms_cased import default_vocabulary_transforms


class CaSEDModel(PreTrainedModel):
    """Transformers module for Category Search from External Databases (CaSED).

    Reference:
        - Conti et al. Vocabulary-free Image Classification. NeurIPS 2023.

    Args:
        config (CaSEDConfig): Configuration class for CaSED.
    """

    config_class = CaSEDConfig

    def __init__(self, config: CaSEDConfig):
        super().__init__(config)

        # load CLIP
        model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
        self.vision_encoder = model.vision_model
        self.vision_proj = model.visual_projection
        self.language_encoder = model.text_model
        self.language_proj = model.text_projection
        self.logit_scale = model.logit_scale.exp()  # exponentiated CLIP temperature
        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")

        # set hparams
        self.hparams = {}
        self.hparams["alpha"] = config.alpha
        self.hparams["index_name"] = config.index_name
        self.hparams["retrieval_num_results"] = config.retrieval_num_results
        self.hparams["cache_dir"] = config.cache_dir

        # create cache dir
        os.makedirs(self.hparams["cache_dir"], exist_ok=True)

        # download data
        download_retrieval_databases(cache_dir=self.hparams["cache_dir"])

        # setup vocabulary
        self.vocabulary = RetrievalDatabase("cc12m", self.hparams["cache_dir"])
        self._vocab_transform = default_vocabulary_transforms()

    @property
    def vocab_transform(self) -> Callable:
        """Get the vocabulary transform.

        The getter wraps the transform in a function that applies it to a list of texts.
        If interested in the transform itself, use `self._vocab_transform`.
        """
        vocab_transform = self._vocab_transform

        def vocabs_transforms(texts: list[str]) -> list[torch.Tensor]:
            return [vocab_transform(text) for text in texts]

        return vocabs_transforms

    def get_vocabulary(self, images_z: Optional[torch.Tensor] = None) -> list[list[str]]:
        """Get the vocabulary for a batch of images.

        Args:
            images_z (torch.Tensor): Batch of image embeddings.

        Returns:
            list[list[str]]: A list of retrieved captions for each image.
        """
        num_samples = self.hparams["retrieval_num_results"]

        assert images_z is not None, "`images_z` must be provided to retrieve a vocabulary."

        # normalize the embeddings and convert them to a 2D float32 query matrix
        images_z = images_z / images_z.norm(dim=-1, keepdim=True)
        images_z = images_z.cpu().detach().numpy().tolist()
        if isinstance(images_z[0], float):
            images_z = [images_z]
        query = np.asarray(images_z, dtype="float32")

        # retrieve the nearest captions in the database for each image embedding
        results = self.vocabulary.query(query, modality="text", num_samples=num_samples)
        vocabularies = [[r["caption"] for r in result] for result in results]
        return vocabularies
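
    # `get_vocabulary` returns one list of retrieved captions per image; e.g. for a batch of
    # two images (illustrative values):
    #   [["a photo of a golden retriever", ...], ["a red sports car parked outside", ...]]
    # with `retrieval_num_results` captions per inner list. `forward` both averages their
    # embeddings (caption centroid) and filters them into candidate category words.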

    def forward(self, images: dict, alpha: Optional[float] = None) -> dict:
        """Forward pass.

        Args:
            images (dict): Dictionary with the images. The expected keys are:
                - pixel_values (torch.Tensor): Pixel values of the images.
            alpha (Optional[float]): Weight of the image-to-word scores in the interpolation
                with the caption-centroid scores. Defaults to the value in the config.

        Returns:
            dict: Dictionary with the scores (`scores`), the flattened candidate words
                (`words`), and the per-image vocabularies (`vocabularies`).
        """
        alpha = alpha if alpha is not None else self.hparams["alpha"]

        # forward the images
        images["pixel_values"] = images["pixel_values"].to(self.device)
        images_z = self.vision_proj(self.vision_encoder(**images)[1])
        images_z = images_z / images_z.norm(dim=-1, keepdim=True)
        vocabularies = self.get_vocabulary(images_z=images_z)

        # encode the unfiltered words (retrieved captions) of all images in a single batch
        unfiltered_words = sum(vocabularies, [])
        texts_z = self.processor(text=unfiltered_words, return_tensors="pt", padding=True)
        # truncate to CLIP's maximum context length of 77 tokens
        texts_z["input_ids"] = texts_z["input_ids"][:, :77].to(self.device)
        texts_z["attention_mask"] = texts_z["attention_mask"][:, :77].to(self.device)
        texts_z = self.language_encoder(**texts_z)[1]
        texts_z = self.language_proj(texts_z)
        texts_z = texts_z / texts_z.norm(dim=-1, keepdim=True)

        # generate a text embedding for each image by averaging the embeddings of its captions
        unfiltered_words_per_image = [len(vocab) for vocab in vocabularies]
        texts_z = torch.split(texts_z, unfiltered_words_per_image)
        texts_z = torch.stack([text_z.mean(dim=0) for text_z in texts_z])
        texts_z = texts_z / texts_z.norm(dim=-1, keepdim=True)

        # filter the retrieved captions into candidate words and embed them
        vocabularies = self.vocab_transform(vocabularies)
        # fall back to a generic "object" label if filtering left a vocabulary empty
        vocabularies = [vocab or ["object"] for vocab in vocabularies]
        words = sum(vocabularies, [])
        words_z = self.processor(text=words, return_tensors="pt", padding=True)
        words_z = {k: v.to(self.device) for k, v in words_z.items()}
        words_z = self.language_encoder(**words_z)[1]
        words_z = self.language_proj(words_z)
        words_z = words_z / words_z.norm(dim=-1, keepdim=True)

        # create a one-hot relation mask between images and words
        words_per_image = [len(vocab) for vocab in vocabularies]
        col_indices = torch.arange(sum(words_per_image))
        row_indices = torch.arange(len(images_z)).repeat_interleave(torch.tensor(words_per_image))
        mask = torch.zeros(len(images_z), sum(words_per_image), device=self.device)
        mask[row_indices, col_indices] = 1
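        # Illustrative example: with two images contributing 3 and 2 candidate words,
        # the mask is
        #   [[1., 1., 1., 0., 0.],
        #    [0., 0., 0., 1., 1.]]
        # so each image is only scored against its own candidate words.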

        # get the image and text similarities against the candidate words
        images_z = images_z / images_z.norm(dim=-1, keepdim=True)
        texts_z = texts_z / texts_z.norm(dim=-1, keepdim=True)
        words_z = words_z / words_z.norm(dim=-1, keepdim=True)
        images_sim = self.logit_scale * images_z @ words_z.T
        texts_sim = self.logit_scale * texts_z @ words_z.T

        # mask out words that do not belong to the image
        images_sim = torch.masked_fill(images_sim, mask == 0, float("-inf"))
        texts_sim = torch.masked_fill(texts_sim, mask == 0, float("-inf"))

        # get the image and text predictions
        images_p = images_sim.softmax(dim=-1)
        texts_p = texts_sim.softmax(dim=-1)

        # interpolate the image and text predictions with weight alpha
        samples_p = alpha * images_p + (1 - alpha) * texts_p

        return {"scores": samples_p, "words": words, "vocabularies": vocabularies}