In [33]:
from typing import List
import requests
from PIL import Image
from transformers import CLIPModel, CLIPProcessor, CLIPFeatureExtractor
import torch

In [41]:
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
# Without a timeout a stalled connection would block this cell indefinitely.
response = requests.get(url, stream=True, timeout=30)
# Fail loudly on 4xx/5xx instead of handing an error page to PIL.
response.raise_for_status()
# PIL reads from the still-open streamed body; `response` stays referenced so
# the underlying connection is not collected before the image is decoded.
image = Image.open(response.raw)

In [None]:
class ClipWrapper:
    """Bundle a CLIP model and its processor behind two embedding helpers.

    Both helpers return the projected features divided by their L2 norm along
    the last dimension, so a dot product between any two returned vectors is
    their cosine similarity.
    """

    def __init__(self, model_name: str = "openai/clip-vit-base-patch32"):
        """Load the model and its matching processor from a single checkpoint.

        Args:
            model_name: Checkpoint identifier handed to both ``from_pretrained``
                calls. Defaults to the checkpoint this class was originally
                hard-wired to, so ``ClipWrapper()`` behaves as before.
        """
        self.model = CLIPModel.from_pretrained(model_name)
        self.processor = CLIPProcessor.from_pretrained(model_name)

    @staticmethod
    def _is_empty_batch(items) -> bool:
        """Return True only for a list/tuple with no elements."""
        # Restricting the check to list/tuple avoids calling bool() on arrays
        # or tensors, which the processor may also be given.
        return isinstance(items, (list, tuple)) and len(items) == 0

    @staticmethod
    def _empty_result(projection) -> torch.Tensor:
        """Build a ``(0, dim)`` tensor matching the projection's output."""
        # The projection is applied as a linear layer; its weight is laid out
        # as (out_features, in_features), so shape[0] is the embedding width.
        weight = projection.weight
        return torch.empty((0, weight.shape[0]), dtype=weight.dtype, device=weight.device)

    def _encode(self, inputs, encoder, projection) -> torch.Tensor:
        """Run ``encoder`` then ``projection`` on ``inputs`` and normalise.

        Args:
            inputs: Mapping of tensor inputs produced by the processor.
            encoder: ``self.model.vision_model`` or ``self.model.text_model``.
            projection: The projection module paired with ``encoder``.

        Returns:
            The projected features scaled to unit L2 norm per row.
        """
        with torch.no_grad():
            device = self.model.device
            model_inputs = {k: v.to(device) for k, v in inputs.items()}
            # Index 1 of the encoder output is the pooled representation in the
            # Hugging Face CLIP outputs; that is what gets projected.
            pooled = encoder(**model_inputs)[1]
            vectors = projection(pooled)
        return vectors / vectors.norm(dim=-1, keepdim=True)

    def images2vec(self, images: List[Image.Image]) -> torch.Tensor:
        """Embed images as unit-length CLIP vectors.

        Args:
            images: Images to embed. An empty list or tuple short-circuits to
                an empty ``(0, dim)`` tensor without invoking the processor.

        Returns:
            One normalised vector per input image.
        """
        if self._is_empty_batch(images):
            return self._empty_result(self.model.visual_projection)
        inputs = self.processor(images=images, return_tensors="pt")
        return self._encode(inputs, self.model.vision_model, self.model.visual_projection)

    def texts2vec(self, texts: List[str]) -> torch.Tensor:
        """Embed strings as unit-length CLIP vectors.

        Args:
            texts: Strings to embed; they are padded to a common length by the
                processor (``padding=True``). An empty list or tuple
                short-circuits to an empty ``(0, dim)`` tensor.

        Returns:
            One normalised vector per input string.
        """
        if self._is_empty_batch(texts):
            return self._empty_result(self.model.text_projection)
        inputs = self.processor(text=texts, return_tensors="pt", padding=True)
        return self._encode(inputs, self.model.text_model, self.model.text_projection)

In [42]:
# A single constant feeds both loaders so the model and its processor can
# never be pointed at different checkpoints by a partial edit.
CLIP_CHECKPOINT = "openai/clip-vit-base-patch32"
model = CLIPModel.from_pretrained(CLIP_CHECKPOINT)
processor = CLIPProcessor.from_pretrained(CLIP_CHECKPOINT)

In [65]:
def images2vec(images: List[Image.Image]) -> torch.Tensor:
    """Embed a batch of images as unit-length CLIP vectors.

    Relies on the notebook-level ``processor`` and ``model`` objects.

    Args:
        images: Images to preprocess and embed together.

    Returns:
        The projected image features, each row divided by its own L2 norm
        along the last dimension.
    """
    pixel_batch = processor(images=images, return_tensors="pt")
    target_device = model.device
    with torch.no_grad():
        on_device = {name: tensor.to(target_device) for name, tensor in pixel_batch.items()}
        vision_out = model.vision_model(**on_device)
        # Element 1 of the vision output is what gets projected
        # (the pooled representation in the Hugging Face CLIP outputs).
        projected = model.visual_projection(vision_out[1])
    lengths = projected.norm(dim=-1, keepdim=True)
    return projected / lengths


# Smoke check: embed the same picture twice and look at the batch shape.
result = images2vec([image] * 2)
result.shape

torch.Size([2, 512])

In [70]:
def texts2vec(texts: List[str]) -> torch.Tensor:
    """Embed a batch of strings as unit-length CLIP vectors.

    Relies on the notebook-level ``processor`` and ``model`` objects.

    Args:
        texts: Strings to tokenize (with padding enabled) and embed together.

    Returns:
        The projected text features, each row divided by its own L2 norm
        along the last dimension.
    """
    encoded = processor(text=texts, return_tensors="pt", padding=True)
    target_device = model.device
    with torch.no_grad():
        on_device = {name: tensor.to(target_device) for name, tensor in encoded.items()}
        # Element 1 of the text output is what gets projected
        # (the pooled representation in the Hugging Face CLIP outputs).
        pooled = model.text_model(**on_device)[1]
        projected = model.text_projection(pooled)
    lengths = projected.norm(dim=-1, keepdim=True)
    return projected / lengths


# Smoke check: two prompts in, inspect the shape of the returned batch.
texts2vec(
    [
        "a photo of a cat",
        "a photo of a dog",
    ]
).shape

torch.Size([2, 512])