import os import tarfile from pathlib import Path from typing import Optional import faiss import numpy as np import pyarrow as pa import requests import torch from tqdm import tqdm from transformers import CLIPModel, CLIPProcessor from transformers.modeling_utils import PreTrainedModel from .configuration_cased import CaSEDConfig from .transforms_cased import default_vocabulary_transforms DATABASES = { "cc12m": { "url": "https://storage-cased.alessandroconti.me/cc12m.tar.gz", "cache_subdir": "./cc12m/vit-l-14/", }, } class MetadataProvider: """Metadata provider. It uses arrow files to store metadata and retrieve it efficiently. Code reference: - https://github.dev/rom1504/clip-retrieval """ def __init__(self, arrow_folder: Path): arrow_files = [str(a) for a in sorted(arrow_folder.glob("**/*")) if a.is_file()] self.table = pa.concat_tables( [ pa.ipc.RecordBatchFileReader(pa.memory_map(arrow_file, "r")).read_all() for arrow_file in arrow_files ] ) def get(self, ids: np.ndarray, cols: Optional[list] = None): """Get arrow metadata from ids. Args: ids (np.ndarray): Ids to retrieve. cols (Optional[list], optional): Columns to retrieve. Defaults to None. """ if cols is None: cols = self.table.schema.names else: cols = list(set(self.table.schema.names) & set(cols)) t = pa.concat_tables([self.table[i:j] for i, j in zip(ids, ids + 1)]) return t.select(cols).to_pandas().to_dict("records") class CaSEDModel(PreTrainedModel): """Transformers module for Category Search from External Databases (CaSED). Reference: - Conti et al. Vocabulary-free Image Classification. arXiv 2023. Args: config (CaSEDConfig): Configuration class for CaSED. """ config_class = CaSEDConfig def __init__(self, config: CaSEDConfig): super().__init__(config) # load CLIP model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14") self.vision_encoder = model.vision_model self.vision_proj = model.visual_projection self.language_encoder = model.text_model self.language_proj = model.text_projection self.logit_scale = model.logit_scale.exp() self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14") # load transforms self.vocabulary_transforms = default_vocabulary_transforms() # set hparams self.hparams = {} self.hparams["alpha"] = config.alpha self.hparams["index_name"] = config.index_name self.hparams["retrieval_num_results"] = config.retrieval_num_results # set cache dir self.hparams["cache_dir"] = Path(os.path.expanduser("~/.cache/cased")) os.makedirs(self.hparams["cache_dir"], exist_ok=True) # download databases self.prepare_data() # load faiss indices and metadata providers self.resources = {} for name, items in DATABASES.items(): database_path = self.hparams["cache_dir"] / "databases" / items["cache_subdir"] text_index_fp = database_path / "text.index" metadata_fp = database_path / "metadata/" text_index = faiss.read_index( str(text_index_fp), faiss.IO_FLAG_MMAP | faiss.IO_FLAG_READ_ONLY ) metadata_provider = MetadataProvider(metadata_fp) self.resources[name] = { "device": self.device, "model": "ViT-L-14", "text_index": text_index, "metadata_provider": metadata_provider, } def prepare_data(self): """Download data if needed.""" databases_path = Path(self.hparams["cache_dir"]) / "databases" for name, items in DATABASES.items(): url = items["url"] database_path = Path(databases_path, name) if database_path.exists(): continue # download data target_path = Path(databases_path, name + ".tar.gz") os.makedirs(target_path.parent, exist_ok=True) with requests.get(url, stream=True) as r: r.raise_for_status() total_bytes_size = int(r.headers.get('content-length', 0)) chunk_size = 8192 p_bar = tqdm( desc="Downloading cc12m index", total=total_bytes_size, unit='iB', unit_scale=True, ) with open(target_path, 'wb') as f: for chunk in r.iter_content(chunk_size=chunk_size): f.write(chunk) p_bar.update(len(chunk)) p_bar.close() # extract data tar = tarfile.open(target_path, "r:gz") tar.extractall(target_path.parent) tar.close() target_path.unlink() @torch.no_grad() def query_index(self, sample_z: torch.Tensor) -> torch.Tensor: """Query the external database index. Args: sample_z (torch.Tensor): Sample to query the index. """ # get the index resources = self.resources[self.hparams["index_name"]] text_index = resources["text_index"] metadata_provider = resources["metadata_provider"] # query the index sample_z = sample_z.squeeze(0) sample_z = sample_z / sample_z.norm(dim=-1, keepdim=True) query_input = sample_z.cpu().detach().numpy().tolist() query = np.expand_dims(np.array(query_input).astype("float32"), 0) distances, idxs, _ = text_index.search_and_reconstruct( query, self.hparams["retrieval_num_results"] ) results = idxs[0] nb_results = np.where(results == -1)[0] nb_results = nb_results[0] if len(nb_results) > 0 else len(results) indices = results[:nb_results] distances = distances[0][:nb_results] if len(distances) == 0: return [] # get the metadata results = [] metadata = metadata_provider.get(indices[:20], ["caption"]) for key, (d, i) in enumerate(zip(distances, indices)): output = {} meta = None if key + 1 > len(metadata) else metadata[key] if meta is not None: output.update(meta) output["id"] = i.item() output["similarity"] = d.item() results.append(output) # get the captions only vocabularies = [result["caption"] for result in results] return vocabularies @torch.no_grad() def forward(self, images: dict, alpha: Optional[float] = None) -> torch.Tensor(): """Forward pass. Args: images (dict): Dictionary with the images. The expected keys are: - pixel_values (torch.Tensor): Pixel values of the images. alpha (Optional[float]): Alpha value for the interpolation. """ # forward the images images["pixel_values"] = images["pixel_values"].to(self.device) images_z = self.vision_proj(self.vision_encoder(**images)[1]) vocabularies, samples_p = [], [] for image_z in images_z: image_z = image_z.unsqueeze(0) # generate a single text embedding from the unfiltered vocabulary vocabulary = self.query_index(image_z) text = self.processor(text=vocabulary, return_tensors="pt", padding=True) text["input_ids"] = text["input_ids"][:, :77].to(self.device) text["attention_mask"] = text["attention_mask"][:, :77].to(self.device) text_z = self.language_encoder(**text)[1] text_z = self.language_proj(text_z) text_z = text_z / text_z.norm(dim=-1, keepdim=True) text_z = text_z.mean(dim=0).unsqueeze(0) text_z = text_z / text_z.norm(dim=-1, keepdim=True) # filter the vocabulary, embed it, and get its mean embedding vocabulary = self.vocabulary_transforms(vocabulary) or ["object"] text = self.processor(text=vocabulary, return_tensors="pt", padding=True) text = {k: v.to(self.device) for k, v in text.items()} vocabulary_z = self.language_encoder(**text)[1] vocabulary_z = self.language_proj(vocabulary_z) vocabulary_z = vocabulary_z / vocabulary_z.norm(dim=-1, keepdim=True) # get the image and text predictions image_z = image_z / image_z.norm(dim=-1, keepdim=True) text_z = text_z / text_z.norm(dim=-1, keepdim=True) image_p = (self.logit_scale * image_z @ vocabulary_z.T).softmax(dim=-1) text_p = (self.logit_scale * text_z @ vocabulary_z.T).softmax(dim=-1) # average the image and text predictions alpha = alpha or self.hparams["alpha"] sample_p = alpha * image_p + (1 - alpha) * text_p # save the results samples_p.append(sample_p) vocabularies.append(vocabulary) # get the scores samples_p = torch.stack(samples_p, dim=0) scores = sample_p.cpu() # define the results results = {"vocabularies": vocabularies, "scores": scores} return results