from typing import Optional, Tuple, List

import torch
import torch.nn.functional as F
from clip.model import CLIP
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import CLIPVisionModelWithProjection

from data_utils import collate_fn
from models import Phi

if torch.cuda.is_available():
    device = torch.device("cuda")
    dtype = torch.float16
else:
    device = torch.device("cpu")
    dtype = torch.float32


@torch.no_grad()
def extract_image_features(dataset: Dataset, clip_model: CLIPVisionModelWithProjection,
                           batch_size: Optional[int] = 32,
                           num_workers: Optional[int] = 10) -> Tuple[torch.Tensor, List[str]]:
    """
    Extracts image features from a dataset using a CLIP model.
    """
    # Create data loader
    loader = DataLoader(dataset=dataset, batch_size=batch_size, num_workers=num_workers,
                        pin_memory=True, collate_fn=collate_fn)

    index_features = []
    index_names = []
    try:
        print(f"Extracting image features {dataset.__class__.__name__} - {dataset.split}")
    except Exception:
        pass

    # Extract features batch by batch; batches may expose either 'image'/'image_name'
    # or 'reference_image'/'reference_name' keys depending on the dataset
    for batch in tqdm(loader):
        images = batch.get('image')
        names = batch.get('image_name')
        if images is None:
            images = batch.get('reference_image')
        if names is None:
            names = batch.get('reference_name')

        images = images.to(clip_model.device)
        with torch.no_grad():
            batch_features = clip_model(pixel_values=images.to(clip_model.dtype)).image_embeds
            index_features.append(batch_features.cpu())
            index_names.extend(names)

    index_features = torch.vstack(index_features)
    return index_features, index_names


def contrastive_loss(v1: torch.Tensor, v2: torch.Tensor, temperature: float) -> torch.Tensor:
    # Based on https://github.com/NVlabs/PALAVRA/blob/main/utils/nv.py
    v1 = F.normalize(v1, dim=1)
    v2 = F.normalize(v2, dim=1)

    numerator = torch.exp(torch.diag(torch.inner(v1, v2)) / temperature)
    numerator = torch.cat((numerator, numerator), 0)
    joint_vector = torch.cat((v1, v2), 0)
    pairs_product = torch.exp(torch.mm(joint_vector, joint_vector.t()) / temperature)
    denominator = torch.sum(pairs_product - pairs_product * torch.eye(joint_vector.shape[0]).to(device), 0)
    loss = -torch.mean(torch.log(numerator / denominator))

    return loss


@torch.no_grad()
def extract_pseudo_tokens_with_phi(clip_model: CLIPVisionModelWithProjection, phi: Phi, dataset: Dataset,
                                   args) -> Tuple[torch.Tensor, List[str]]:
    """
    Extracts pseudo tokens from a dataset using a CLIP model and a phi model.
    """
    data_loader = DataLoader(dataset=dataset, batch_size=32, num_workers=10, pin_memory=False,
                             collate_fn=collate_fn)
    predicted_tokens = []
    names_list = []
    print("Extracting pseudo tokens using the phi model")

    for batch in tqdm(data_loader):
        images = batch.get('image')
        names = batch.get('image_name')
        if images is None:
            images = batch.get('reference_image')
        if names is None:
            names = batch.get('reference_name')

        images = images.to(device)
        image_features = clip_model(pixel_values=images.to(clip_model.dtype)).image_embeds
        if args.l2_normalize:
            image_features = F.normalize(image_features, dim=-1)
        batch_predicted_tokens = phi(image_features)
        predicted_tokens.append(batch_predicted_tokens.cpu())
        names_list.extend(names)

    predicted_tokens = torch.vstack(predicted_tokens)
    return predicted_tokens, names_list


@torch.no_grad()
def extract_image_features_with_names(clip_model: CLIPVisionModelWithProjection,
                                      dataset: Dataset) -> Tuple[torch.Tensor, List[str]]:
    """
    Extracts image features from a dataset using a CLIP model.
    """
    data_loader = DataLoader(dataset=dataset, batch_size=32, num_workers=10, pin_memory=False,
                             collate_fn=collate_fn)
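    # Same iteration pattern as extract_image_features, but the raw projected
    # image embeddings are kept as-is (no phi mapping) and returned with names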
    predicted_tokens = []
    names_list = []
    print("Extracting image features")

    for batch in tqdm(data_loader):
        images = batch.get('image')
        names = batch.get('image_name')
        if images is None:
            images = batch.get('reference_image')
        if names is None:
            names = batch.get('reference_name')

        images = images.to(device)
        image_features = clip_model(pixel_values=images.to(clip_model.dtype)).image_embeds
        batch_predicted_tokens = image_features
        predicted_tokens.append(batch_predicted_tokens.cpu())
        names_list.extend(names)

    predicted_tokens = torch.vstack(predicted_tokens)
    return predicted_tokens, names_list


class CustomTensorDataset(Dataset):
    """
    Custom Tensor Dataset which yields image_features and image_names.
    """

    def __init__(self, images: torch.Tensor, names: List[str]):
        self.images = images
        self.names = names

    def __getitem__(self, index) -> dict:
        return {'image': self.images[index],
                'image_name': self.names[index]
                }

    def __len__(self):
        return len(self.images)


def get_templates():
    """
    Return a list of templates
    Same templates as in PALAVRA: https://arxiv.org/abs/2204.01694
    """
    return [
        "This is a photo of a {}",
        "This photo contains a {}",
        "A photo of a {}",
        "This is an illustration of a {}",
        "This illustration contains a {}",
        "An illustrations of a {}",
        "This is a sketch of a {}",
        "This sketch contains a {}",
        "A sketch of a {}",
        "This is a diagram of a {}",
        "This diagram contains a {}",
        "A diagram of a {}",
        "A {}",
        "We see a {}",
        "{}",
        "We see a {} in this photo",
        "We see a {} in this image",
        "We see a {} in this illustration",
        "We see a {} photo",
        "We see a {} image",
        "We see a {} illustration",
        "{} photo",
        "{} image",
        "{} illustration",
    ]
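
# -----------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, kept commented out). It assumes a
# CLIP ViT-L/14 checkpoint name and feeds random pixel tensors through
# CustomTensorDataset just to show the expected call flow; in practice the
# dataset would yield real preprocessed images, and phi would be constructed
# according to the signature defined in models.Phi (not shown here).
# -----------------------------------------------------------------------------
# if __name__ == "__main__":
#     clip_model = CLIPVisionModelWithProjection.from_pretrained(
#         "openai/clip-vit-large-patch14"
#     ).to(device, dtype=dtype).eval()
#
#     dummy_images = torch.rand(8, 3, 224, 224)
#     dummy_names = [f"img_{i}.png" for i in range(8)]
#     dataset = CustomTensorDataset(dummy_images, dummy_names)
#
#     features, names = extract_image_features(dataset, clip_model, batch_size=4)
#     print(features.shape, len(names))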