|
from io import BytesIO |
|
import base64 |
|
|
|
from PIL import Image |
|
import torch |
|
from transformers import CLIPProcessor, CLIPTextModel, CLIPVisionModelWithProjection |
|
|
|
# Run the models on the GPU when CUDA is usable; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
|
|
class EndpointHandler:
    """Callable request handler that returns a CLIP embedding for text or an image.

    A request either carries a dict payload (``{"inputs": {...}}``) or raw
    image bytes. Text is embedded with ``CLIPTextModel``; images are embedded
    with ``CLIPVisionModelWithProjection``.

    NOTE(review): the text branch returns ``pooler_output`` from a model
    without a projection head, while the image branch returns ``image_embeds``
    from a model with one. Confirm that callers do not expect the two kinds of
    embedding to be directly comparable.
    """

    # Checkpoint used for both encoders and for the processor.
    MODEL_ID = "rbanfield/clip-vit-large-patch14"

    def __init__(self, path=""):
        """Load both encoders onto ``device`` and build the processor.

        Args:
            path: Accepted for interface compatibility but not used; the
                weights are always loaded from ``MODEL_ID``.
        """
        self.text_model = CLIPTextModel.from_pretrained(self.MODEL_ID).to(device)
        self.image_model = CLIPVisionModelWithProjection.from_pretrained(self.MODEL_ID).to(device)
        self.processor = CLIPProcessor.from_pretrained(self.MODEL_ID)

    def __call__(self, data):
        """Embed the text or image contained in ``data``.

        Args:
            data: Either a dict or a bytes-like object.

                * dict: ``data["inputs"]`` may be a dict with a ``"text"``
                  entry and/or a base64-encoded ``"image"`` entry, or a plain
                  string that is treated as text. A missing, ``None`` or
                  otherwise unsupported ``"inputs"`` value yields no
                  embedding instead of raising.
                * anything else: passed to ``BytesIO`` and decoded as an image.

        Returns:
            ``{'embeddings': values}`` where ``values`` is the first row of
            the model output converted to a Python list, or
            ``{'embeddings': None}`` when there is nothing to embed. Non-empty
            text takes precedence over an image when both are supplied.
        """
        text_input = None
        image_data = None

        if isinstance(data, dict):
            # Read with get() rather than pop() so the caller's dict is not modified.
            inputs = data.get("inputs")
            if isinstance(inputs, dict):
                text_input = inputs.get("text")
                encoded_image = inputs.get("image")
                # A None image is treated as absent instead of failing in b64decode.
                if encoded_image is not None:
                    image_data = BytesIO(base64.b64decode(encoded_image))
            elif isinstance(inputs, str):
                # Accept the bare-string form {"inputs": "some text"} as text.
                text_input = inputs
        else:
            image_data = BytesIO(data)

        if text_input:
            # truncation=True lets the tokenizer clip over-long text to its
            # maximum length instead of handing the encoder a longer sequence.
            encoded = self.processor(
                text=text_input,
                return_tensors="pt",
                padding=True,
                truncation=True,
            ).to(device)
            with torch.no_grad():
                pooled = self.text_model(**encoded).pooler_output
            return {'embeddings': pooled.tolist()[0]}

        if image_data is not None:
            image = Image.open(image_data)
            encoded = self.processor(images=image, return_tensors="pt").to(device)
            with torch.no_grad():
                image_embeds = self.image_model(**encoded).image_embeds
            return {'embeddings': image_embeds.tolist()[0]}

        return {'embeddings': None}