# handler.py — copied from the Hugging Face Hub file view (author: rbanfield).
# Last commit 418414b, "Update handler.py"; file size 1.42 kB; Hub scan: "No virus".
from io import BytesIO
import base64
from PIL import Image
import torch
from transformers import CLIPProcessor, CLIPModel
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
class EndpointHandler():
    """Return CLIP text or image embeddings for a single request.

    The model and processor are loaded once from the hub id
    ``rbanfield/clip-vit-large-patch14``. Each call encodes either the
    supplied text or the supplied image into feature vectors.
    """

    def __init__(self, path=""):
        """Load the CLIP model and processor.

        Args:
            path: Accepted for interface compatibility but not used; weights
                are always loaded from the hard-coded hub id.
                NOTE(review): if this handler lives inside that same repo,
                loading from ``path`` might avoid a network download — confirm.
        """
        # Place the model on the same device that __call__ moves the request
        # tensors to. Pinning it to CPU broke inference on CUDA hosts
        # (inputs on cuda, weights on cpu).
        self.model = CLIPModel.from_pretrained("rbanfield/clip-vit-large-patch14").to(device)
        self.processor = CLIPProcessor.from_pretrained("rbanfield/clip-vit-large-patch14")

    def __call__(self, data):
        """Embed the text or the image carried by ``data``.

        Args:
            data: Either a dict whose ``"inputs"`` entry is a dict with an
                optional ``"text"`` value (passed straight to the processor)
                and/or an optional ``"image"`` value (base64-encoded image
                bytes), or a raw image byte string.

        Returns:
            ``{"embeddings": <nested list>}`` holding the text features when
            text is present (text wins over an image), otherwise the image
            features. Returns ``{"embeddings": None}`` when neither is given.
        """
        text_input = None
        image_data = None
        if isinstance(data, dict):
            # Read without mutating the caller's dict. Treat a missing/null
            # "inputs" as empty instead of crashing on None.get(...).
            inputs = data.get("inputs") or {}
            text_input = inputs.get("text", None)
            if "image" in inputs:
                image_data = BytesIO(base64.b64decode(inputs["image"]))
        else:
            # Non-dict payloads are assumed to be a raw image body.
            image_data = BytesIO(data)

        if text_input:
            encoded = self.processor(text=text_input, return_tensors="pt", padding=True).to(device)
            with torch.no_grad():
                features = self.model.get_text_features(**encoded)
        elif image_data is not None:
            image = Image.open(image_data)
            encoded = self.processor(images=image, return_tensors="pt").to(device)
            with torch.no_grad():
                features = self.model.get_image_features(**encoded)
        else:
            return {"embeddings": None}
        return {"embeddings": features.tolist()}