import base64
from io import BytesIO

import torch
from PIL import Image
from transformers import CLIPProcessor, CLIPModel

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class EndpointHandler:
    def __init__(self, path=""):
        # Load the CLIP model and processor once at startup. The model must
        # live on the same device the inputs are moved to in __call__,
        # otherwise inference fails whenever CUDA is available.
        self.model = CLIPModel.from_pretrained("rbanfield/clip-vit-large-patch14").to(device)
        self.model.eval()
        self.processor = CLIPProcessor.from_pretrained("rbanfield/clip-vit-large-patch14")

    def __call__(self, data):
        text_input = None
        image_data = None
        if isinstance(data, dict):
            # Guard against a missing "inputs" key so .get() never runs on None.
            inputs = data.pop("inputs", None) or {}
            text_input = inputs.get("text")
            image_data = BytesIO(base64.b64decode(inputs["image"])) if "image" in inputs else None
        else:
            # Assume the payload is a raw image sent as binary data.
            image_data = BytesIO(data)

        if text_input:
            # Tokenize the text and return the CLIP text embedding.
            processed = self.processor(text=text_input, return_tensors="pt", padding=True).to(device)
            with torch.no_grad():
                return {"embeddings": self.model.get_text_features(**processed).tolist()}
        elif image_data:
            # Decode the image bytes and return the CLIP image embedding.
            image = Image.open(image_data)
            processed = self.processor(images=image, return_tensors="pt").to(device)
            with torch.no_grad():
                return {"embeddings": self.model.get_image_features(**processed).tolist()}
        else:
            return {"embeddings": None}
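
# A minimal local smoke test, assuming the handler receives the same JSON
# shape a deployed endpoint would pass in ({"inputs": {...}}). The file path
# "example.jpg" is a placeholder for illustration, not part of the handler.
if __name__ == "__main__":
    handler = EndpointHandler()

    # Text embedding: pass a string (or list of strings) under "text".
    text_out = handler({"inputs": {"text": "a photo of a cat"}})
    print(len(text_out["embeddings"][0]))  # projection dim (768 for ViT-L/14)

    # Image embedding: base64-encode the raw image bytes under "image".
    with open("example.jpg", "rb") as f:
        payload = base64.b64encode(f.read()).decode("utf-8")
    image_out = handler({"inputs": {"image": payload}})
    print(len(image_out["embeddings"][0]))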