import os
from typing import Callable, Optional

import numpy as np
import torch
from transformers import CLIPModel, CLIPProcessor
from transformers.modeling_utils import PreTrainedModel

from .configuration_cased import CaSEDConfig
from .retrieval_cased import RetrievalDatabase, download_retrieval_databases
from .transforms_cased import default_vocabulary_transforms


class CaSEDModel(PreTrainedModel):
    """Transformers module for Category Search from External Databases (CaSED).

    Reference:
        - Conti et al. Vocabulary-free Image Classification. NeurIPS 2023.

    Args:
        config (CaSEDConfig): Configuration class for CaSED.
    """

    config_class = CaSEDConfig

    def __init__(self, config: CaSEDConfig):
        super().__init__(config)

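        # load the CLIP model and split it into its vision and language components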
        model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
        self.vision_encoder = model.vision_model
        self.vision_proj = model.visual_projection
        self.language_encoder = model.text_model
        self.language_proj = model.text_projection
        self.logit_scale = model.logit_scale.exp()
        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")

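        # save the hyperparameters from the config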
        self.hparams = {}
        self.hparams["alpha"] = config.alpha
        self.hparams["index_name"] = config.index_name
        self.hparams["retrieval_num_results"] = config.retrieval_num_results
        self.hparams["cache_dir"] = config.cache_dir

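        # create the cache directory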
        os.makedirs(self.hparams["cache_dir"], exist_ok=True)

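        # download the retrieval databases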
        download_retrieval_databases(cache_dir=self.hparams["cache_dir"])

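        # load the retrieval database and the vocabulary transforms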
        self.vocabulary = RetrievalDatabase("cc12m", self.hparams["cache_dir"])
        self._vocab_transform = default_vocabulary_transforms()

    @property
    def vocab_transform(self) -> Callable:
        """Get the vocabulary transform.

        The getter wraps the transform so that it applies to a list of texts. If interested
        in the transform itself, use `self._vocab_transform`.
        """
        vocab_transform = self._vocab_transform

        def vocabs_transforms(texts: list[str]) -> list[torch.Tensor]:
            return [vocab_transform(text) for text in texts]

        return vocabs_transforms

    def get_vocabulary(self, images_z: Optional[torch.Tensor] = None) -> list[list[str]]:
        """Get the vocabulary for a batch of images.

        Args:
            images_z (torch.Tensor): Batch of image embeddings.
        """
        num_samples = self.hparams["retrieval_num_results"]

        assert images_z is not None

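        # normalize the image embeddings and convert them to (nested) lists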
        images_z = images_z / images_z.norm(dim=-1, keepdim=True)
        images_z = images_z.cpu().detach().numpy().tolist()

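        # handle the case of a single, unbatched image embedding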
        if isinstance(images_z[0], float):
            images_z = [images_z]

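        # query the retrieval database for the captions closest to each image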
        query = np.matrix(images_z).astype("float32")
        results = self.vocabulary.query(query, modality="text", num_samples=num_samples)

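        # collect the retrieved captions as per-image vocabularies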
        vocabularies = [[r["caption"] for r in result] for result in results]
        return vocabularies

    def forward(self, images: dict, alpha: Optional[float] = None) -> dict:
        """Forward pass.

        Args:
            images (dict): Dictionary with the images. The expected keys are:
                - pixel_values (torch.Tensor): Pixel values of the images.
            alpha (Optional[float]): Weight of the image predictions when interpolating them
                with the text predictions. Defaults to the value from the config.
        """
        alpha = alpha if alpha is not None else self.hparams["alpha"]

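        # encode the images with the vision tower and retrieve their vocabularies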
images["pixel_values"] = images["pixel_values"].to(self.device) |
|
images_z = self.vision_proj(self.vision_encoder(**images)[1]) |
|
images_z = images_z / images_z.norm(dim=-1, keepdim=True) |
|
vocabularies = self.get_vocabulary(images_z=images_z) |
|
|
|
|
|
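        # encode the unfiltered words (the retrieved captions)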
        unfiltered_words = sum(vocabularies, [])
        texts_z = self.processor(unfiltered_words, return_tensors="pt", padding=True)
        texts_z["input_ids"] = texts_z["input_ids"][:, :77].to(self.device)
        texts_z["attention_mask"] = texts_z["attention_mask"][:, :77].to(self.device)
        texts_z = self.language_encoder(**texts_z)[1]
        texts_z = self.language_proj(texts_z)
        texts_z = texts_z / texts_z.norm(dim=-1, keepdim=True)

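        # generate one text embedding per image by averaging its unfiltered word embeddings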
        unfiltered_words_per_image = [len(vocab) for vocab in vocabularies]
        texts_z = torch.split(texts_z, unfiltered_words_per_image)
        texts_z = torch.stack([text_z.mean(dim=0) for text_z in texts_z])
        texts_z = texts_z / texts_z.norm(dim=-1, keepdim=True)

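        # filter the candidate words and encode them, falling back to "object" for empty vocabularies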
        vocabularies = self.vocab_transform(vocabularies)
        vocabularies = [vocab or ["object"] for vocab in vocabularies]
        words = sum(vocabularies, [])
        words_z = self.processor(words, return_tensors="pt", padding=True)
        words_z = {k: v.to(self.device) for k, v in words_z.items()}
        words_z = self.language_encoder(**words_z)[1]
        words_z = self.language_proj(words_z)
        words_z = words_z / words_z.norm(dim=-1, keepdim=True)

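        # build a one-hot mask that pairs each image with its own candidate words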
        words_per_image = [len(vocab) for vocab in vocabularies]
        col_indices = torch.arange(sum(words_per_image))
        row_indices = torch.arange(len(images_z)).repeat_interleave(torch.tensor(words_per_image))
        mask = torch.zeros(len(images_z), sum(words_per_image), device=self.device)
        mask[row_indices, col_indices] = 1

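        # compute the image-to-word and text-to-word similarities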
        images_z = images_z / images_z.norm(dim=-1, keepdim=True)
        texts_z = texts_z / texts_z.norm(dim=-1, keepdim=True)
        words_z = words_z / words_z.norm(dim=-1, keepdim=True)
        images_sim = self.logit_scale * images_z @ words_z.T
        texts_sim = self.logit_scale * texts_z @ words_z.T

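        # mask out words that do not belong to the image's own vocabulary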
        images_sim = torch.masked_fill(images_sim, mask == 0, float("-inf"))
        texts_sim = torch.masked_fill(texts_sim, mask == 0, float("-inf"))

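        # convert the similarities into per-word probabilities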
        images_p = images_sim.softmax(dim=-1)
        texts_p = texts_sim.softmax(dim=-1)

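        # interpolate the image and text predictions with alpha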
        samples_p = alpha * images_p + (1 - alpha) * texts_p

return {"scores": samples_p, "words": words, "vocabularies": vocabularies} |
|
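

# Usage sketch (assumes the model is published on the Hugging Face Hub as a custom model;
# the repository id below is a placeholder):
#
#   from PIL import Image
#   from transformers import AutoModel
#
#   model = AutoModel.from_pretrained("<namespace>/cased", trust_remote_code=True)
#   images = model.processor(images=[Image.open("example.jpg")], return_tensors="pt")
#   outputs = model(images)
#   scores, words = outputs["scores"], outputs["words"]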