import os
import tarfile
from pathlib import Path
from typing import Optional

import faiss
import numpy as np
import pyarrow as pa
import requests
import torch
from tqdm import tqdm
from transformers import CLIPModel, CLIPProcessor
from transformers.modeling_utils import PreTrainedModel

from .configuration_cased import CaSEDConfig
from .transforms_cased import default_vocabulary_transforms

DATABASES = {
    "cc12m": {
        "url": "https://storage-cased.alessandroconti.me/cc12m.tar.gz",
        "cache_subdir": "./cc12m/vit-l-14/",
    },
}

class MetadataProvider:
    """Metadata provider.

    It uses arrow files to store metadata and retrieve it efficiently.

    Code reference:
    - https://github.dev/rom1504/clip-retrieval
    """

    def __init__(self, arrow_folder: Path):
        # concatenate every arrow file in the folder into a single in-memory table
        arrow_files = [str(a) for a in sorted(arrow_folder.glob("**/*")) if a.is_file()]
        self.table = pa.concat_tables(
            [
                pa.ipc.RecordBatchFileReader(pa.memory_map(arrow_file, "r")).read_all()
                for arrow_file in arrow_files
            ]
        )

    def get(self, ids: np.ndarray, cols: Optional[list] = None):
        """Get arrow metadata from ids.

        Args:
            ids (np.ndarray): Ids to retrieve.
            cols (Optional[list], optional): Columns to retrieve. Defaults to None.
        """
        if cols is None:
            cols = self.table.schema.names
        else:
            cols = list(set(self.table.schema.names) & set(cols))
        # select one row per id and return the requested columns as a list of dicts
        t = pa.concat_tables([self.table[i:j] for i, j in zip(ids, ids + 1)])
        return t.select(cols).to_pandas().to_dict("records")
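
# Minimal usage sketch for `MetadataProvider` (the path below is where `prepare_data`
# is expected to place the cc12m metadata; adjust it to your cache location):
#
#   provider = MetadataProvider(
#       Path("~/.cache/cased/databases/cc12m/vit-l-14/metadata").expanduser()
#   )
#   rows = provider.get(np.array([0, 1, 2]), cols=["caption"])
#   # rows is a list of dicts, e.g. [{"caption": "..."}, {"caption": "..."}, ...]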


class CaSEDModel(PreTrainedModel):
    """Transformers module for Category Search from External Databases (CaSED).

    Reference:
    - Conti et al. Vocabulary-free Image Classification. arXiv 2023.

    Args:
        config (CaSEDConfig): Configuration class for CaSED.
    """

    config_class = CaSEDConfig

    def __init__(self, config: CaSEDConfig):
        super().__init__(config)

        # load the CLIP ViT-L/14 vision and language towers
        model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
        self.vision_encoder = model.vision_model
        self.vision_proj = model.visual_projection
        self.language_encoder = model.text_model
        self.language_proj = model.text_projection
        self.logit_scale = model.logit_scale.exp()
        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")

        # load the vocabulary transforms
        self.vocabulary_transforms = default_vocabulary_transforms()

        # set the hyperparameters
        self.hparams = {}
        self.hparams["alpha"] = config.alpha
        self.hparams["index_name"] = config.index_name
        self.hparams["retrieval_num_results"] = config.retrieval_num_results

        # create the cache directory
        self.hparams["cache_dir"] = Path(os.path.expanduser("~/.cache/cased"))
        os.makedirs(self.hparams["cache_dir"], exist_ok=True)

        # download the external databases if needed
        self.prepare_data()

        # load the faiss indices and metadata providers
        self.resources = {}
        for name, items in DATABASES.items():
            database_path = self.hparams["cache_dir"] / "databases" / items["cache_subdir"]
            text_index_fp = database_path / "text.index"
            metadata_fp = database_path / "metadata/"

            text_index = faiss.read_index(
                str(text_index_fp), faiss.IO_FLAG_MMAP | faiss.IO_FLAG_READ_ONLY
            )
            metadata_provider = MetadataProvider(metadata_fp)

            self.resources[name] = {
                "device": self.device,
                "model": "ViT-L-14",
                "text_index": text_index,
                "metadata_provider": metadata_provider,
            }

    def prepare_data(self):
        """Download data if needed."""
        databases_path = Path(self.hparams["cache_dir"]) / "databases"

        for name, items in DATABASES.items():
            url = items["url"]
            database_path = Path(databases_path, name)
            if database_path.exists():
                continue

            # download the database tarball with a progress bar
            target_path = Path(databases_path, name + ".tar.gz")
            os.makedirs(target_path.parent, exist_ok=True)
            with requests.get(url, stream=True) as r:
                r.raise_for_status()
                total_bytes_size = int(r.headers.get("content-length", 0))
                chunk_size = 8192
                p_bar = tqdm(
                    desc=f"Downloading {name} index",
                    total=total_bytes_size,
                    unit="iB",
                    unit_scale=True,
                )
                with open(target_path, "wb") as f:
                    for chunk in r.iter_content(chunk_size=chunk_size):
                        f.write(chunk)
                        p_bar.update(len(chunk))
                p_bar.close()

            # extract the archive and remove the tarball
            with tarfile.open(target_path, "r:gz") as tar:
                tar.extractall(target_path.parent)
            target_path.unlink()

    @torch.no_grad()
    def query_index(self, sample_z: torch.Tensor) -> list:
        """Query the external database index.

        Args:
            sample_z (torch.Tensor): Sample to query the index.

        Returns:
            list: Captions of the retrieved database entries.
        """
        # get the resources of the selected index
        resources = self.resources[self.hparams["index_name"]]
        text_index = resources["text_index"]
        metadata_provider = resources["metadata_provider"]

        # build the query from the normalized sample embedding
        sample_z = sample_z.squeeze(0)
        sample_z = sample_z / sample_z.norm(dim=-1, keepdim=True)
        query_input = sample_z.cpu().detach().numpy().tolist()
        query = np.expand_dims(np.array(query_input).astype("float32"), 0)

        # search the index and drop the padding (-1) results
        distances, idxs, _ = text_index.search_and_reconstruct(
            query, self.hparams["retrieval_num_results"]
        )
        results = idxs[0]
        nb_results = np.where(results == -1)[0]
        nb_results = nb_results[0] if len(nb_results) > 0 else len(results)
        indices = results[:nb_results]
        distances = distances[0][:nb_results]

        if len(distances) == 0:
            return []

        # collect the metadata (captions) of the retrieved entries
        results = []
        metadata = metadata_provider.get(indices[:20], ["caption"])
        for key, (d, i) in enumerate(zip(distances, indices)):
            output = {}
            meta = None if key + 1 > len(metadata) else metadata[key]
            if meta is not None:
                output.update(meta)
            output["id"] = i.item()
            output["similarity"] = d.item()
            results.append(output)

        # return the captions as the retrieved vocabulary
        vocabularies = [result["caption"] for result in results]

        return vocabularies
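
    # Minimal usage sketch for `query_index` (assumes the cc12m index has been
    # downloaded by `prepare_data` and that `image_z` is a projected CLIP ViT-L/14
    # image embedding of shape (1, 768)):
    #
    #   captions = model.query_index(image_z)
    #   # captions -> ["a dog playing in the park", ...]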

    @torch.no_grad()
    def forward(self, images: dict, alpha: Optional[float] = None) -> dict:
        """Forward pass.

        Args:
            images (dict): Dictionary with the images. The expected keys are:
                - pixel_values (torch.Tensor): Pixel values of the images.
            alpha (Optional[float]): Alpha value for the interpolation.

        Returns:
            dict: Dictionary with the results. The keys are:
                - vocabularies (list): Retrieved vocabulary for each image.
                - scores (torch.Tensor): Scores over the vocabulary for each image.
        """
        # get the image embeddings
        images["pixel_values"] = images["pixel_values"].to(self.device)
        images_z = self.vision_proj(self.vision_encoder(**images)[1])

        vocabularies, samples_p = [], []
        for image_z in images_z:
            image_z = image_z.unsqueeze(0)

            # retrieve the captions closest to the image and encode their centroid
            vocabulary = self.query_index(image_z)
            text = self.processor(text=vocabulary, return_tensors="pt", padding=True)
            text["input_ids"] = text["input_ids"][:, :77].to(self.device)
            text["attention_mask"] = text["attention_mask"][:, :77].to(self.device)
            text_z = self.language_encoder(**text)[1]
            text_z = self.language_proj(text_z)
            text_z = text_z / text_z.norm(dim=-1, keepdim=True)
            text_z = text_z.mean(dim=0).unsqueeze(0)
            text_z = text_z / text_z.norm(dim=-1, keepdim=True)

            # extract the candidate categories from the captions and encode them
            vocabulary = self.vocabulary_transforms(vocabulary) or ["object"]
            text = self.processor(text=vocabulary, return_tensors="pt", padding=True)
            text = {k: v.to(self.device) for k, v in text.items()}
            vocabulary_z = self.language_encoder(**text)[1]
            vocabulary_z = self.language_proj(vocabulary_z)
            vocabulary_z = vocabulary_z / vocabulary_z.norm(dim=-1, keepdim=True)

            # get the image and text predictions over the vocabulary
            image_z = image_z / image_z.norm(dim=-1, keepdim=True)
            text_z = text_z / text_z.norm(dim=-1, keepdim=True)
            image_p = (self.logit_scale * image_z @ vocabulary_z.T).softmax(dim=-1)
            text_p = (self.logit_scale * text_z @ vocabulary_z.T).softmax(dim=-1)

            # interpolate the image and text predictions
            alpha = self.hparams["alpha"] if alpha is None else alpha
            sample_p = alpha * image_p + (1 - alpha) * text_p

            # save the results
            samples_p.append(sample_p)
            vocabularies.append(vocabulary)

        # stack the per-sample predictions
        samples_p = torch.stack(samples_p, dim=0)
        scores = samples_p.cpu()

        results = {"vocabularies": vocabularies, "scores": scores}

        return results
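

# Minimal end-to-end usage sketch (the checkpoint id below is a placeholder; load
# whichever repository hosts this model, with `trust_remote_code=True` since the
# class is defined in this file):
#
#   from PIL import Image
#   from transformers import AutoModel, CLIPProcessor
#
#   model = AutoModel.from_pretrained("<repo-id>", trust_remote_code=True)
#   processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
#
#   images = processor(images=[Image.open("example.jpg")], return_tensors="pt")
#   outputs = model(images, alpha=0.5)
#   vocabulary, scores = outputs["vocabularies"][0], outputs["scores"][0]
#   print(vocabulary[scores.argmax().item()])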