import base64
import logging
from io import BytesIO
from typing import List

import clip
import numpy as np
import torch
from PIL import Image
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

logging.basicConfig(
    format='%(asctime)s.%(msecs)03d %(levelname)-8s %(message)s',
    level=logging.DEBUG,
    datefmt='%Y-%m-%d %H:%M:%S'
)

# Load explicitly on the CPU. Without `device=`, clip.load() picks CUDA when it
# is available and returns a float16 model there; moving that model to the CPU
# afterwards leaves half-precision weights that break CPU inference. With
# device="cpu", clip.load() keeps the weights in float32.
model, preprocess = clip.load("models/ViT-B-32.pt", device="cpu")
model.eval()

app = FastAPI()


class Texts(BaseModel):
    """Request body for /clip_text_to_text."""

    texts_source: List[str]  # each is embedded as "This is {text}"
    texts_target: List[str]  # embedded verbatim


class Images(BaseModel):
    """Request body for /clip_image_to_text."""

    # Each string must contain 'base64,' followed by the base64 payload
    # (for example a data URL); everything before the last 'base64,' is ignored.
    images: List[str]
    texts: List[str]  # each is embedded as "This is the {text}"


class CLIPResponse(BaseModel):
    """Similarity matrix returned by both CLIP endpoints."""

    similarity: List[List[float]]


def _decode_image(encoded: str) -> torch.Tensor:
    """Decode one base64 image string and apply CLIP preprocessing.

    Raises:
        HTTPException(422): the string has no 'base64,' marker, the payload is
            not valid base64, or PIL cannot read the decoded bytes as an image.
    """
    if 'base64,' not in encoded:
        raise HTTPException(422, "Image must be in base64")
    try:
        raw = base64.b64decode(encoded.split('base64,')[-1])
        # Image.open() is lazy: the pixel data is only decoded when
        # preprocess() touches it, so both calls belong inside the try and
        # inside the with-block (which also closes the image afterwards).
        with Image.open(BytesIO(raw)) as image:
            return preprocess(image)
    except (ValueError, OSError) as err:
        # binascii.Error is a ValueError; PIL's UnidentifiedImageError and
        # truncated-image errors are OSErrors.
        raise HTTPException(422, "Image could not be decoded") from err


def _tokenize(texts: List[str]) -> torch.Tensor:
    """Tokenize texts for CLIP.

    Raises:
        HTTPException(422): the list is empty, or a text is longer than the
            model's context length.
    """
    if not texts:
        raise HTTPException(422, "At least one text is required")
    try:
        return clip.tokenize(texts)
    except RuntimeError as err:
        # clip.tokenize raises RuntimeError for inputs exceeding the context length.
        raise HTTPException(422, str(err)) from err


def _normalize(features: torch.Tensor) -> torch.Tensor:
    """Scale each row of `features` to unit L2 norm."""
    return features / features.norm(dim=-1, keepdim=True)


def _encode_text(tokens: torch.Tensor) -> torch.Tensor:
    """Return L2-normalised float32 CLIP text embeddings for `tokens`."""
    with torch.no_grad():
        features = model.encode_text(tokens).float()
    return _normalize(features)


def _similarity_response(rows: torch.Tensor, cols: torch.Tensor) -> CLIPResponse:
    """Build a response holding `rows @ cols.T`.

    Entry [i][j] compares rows[i] with cols[j]; with L2-normalised inputs
    this is cosine similarity.
    """
    similarity = rows.cpu().numpy() @ cols.cpu().numpy().T
    logging.debug("Similarity: %s", similarity)
    return CLIPResponse(similarity=similarity.tolist())


@app.post("/clip_image_to_text", response_model=CLIPResponse, tags=["CLIP"])
def clip_image_to_text(data: Images) -> CLIPResponse:
    """Score every text against every image.

    Returns a matrix with one row per entry of `data.texts` and one column
    per entry of `data.images`. Invalid or empty input yields HTTP 422.
    """
    if not data.images:
        raise HTTPException(422, "At least one image is required")
    image_input = torch.stack([_decode_image(image) for image in data.images])
    # Validate/tokenize all input before running the (expensive) encoders.
    text_tokens = _tokenize(["This is the " + desc for desc in data.texts])

    with torch.no_grad():
        image_features = _normalize(model.encode_image(image_input).float())
    text_features = _encode_text(text_tokens)
    return _similarity_response(text_features, image_features)


@app.post("/clip_text_to_text", response_model=CLIPResponse, tags=["CLIP"])
def clip_text_to_text(data: Texts) -> CLIPResponse:
    """Score every target text against every source text.

    Returns a matrix with one row per entry of `data.texts_target` and one
    column per entry of `data.texts_source`. Invalid or empty input yields
    HTTP 422.
    """
    # Tokenize both sides first so bad input is rejected before any encoding.
    source_tokens = _tokenize([f"This is {text}" for text in data.texts_source])
    target_tokens = _tokenize(data.texts_target)

    input_features = _encode_text(source_tokens)
    output_features = _encode_text(target_tokens)
    return _similarity_response(output_features, input_features)


@app.get("/ping", tags=["TEST"])
def ping():
    """Liveness check; always returns "pong"."""
    return "pong"