from io import BytesIO
import base64
import traceback
import logging
from PIL import Image
import torch
from transformers import CLIPProcessor, CLIPModel
# Run on GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
class EndpointHandler:
    def __init__(self, path=""):
        # Load the CLIP model and processor once at startup; `path` is the model
        # directory supplied by the Inference Endpoints runtime (unused here).
        self.model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").to(device)
        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
    def __call__(self, data):
        try:
            inputs = data.pop("inputs", None)
            text_input = None
            image = None
            if isinstance(inputs, Image.Image):
                # The image was passed in directly as a PIL image.
                logger.info("image sent directly")
                image = inputs
            elif isinstance(inputs, dict):
                # Otherwise expect a JSON payload with an optional "text" field
                # and/or a base64-encoded "image" field.
                text_input = inputs.get("text")
                image_data = inputs.get("image")
                if image_data is not None:
                    logger.info("image is base64 encoded")
                    image = Image.open(BytesIO(base64.b64decode(image_data)))
            if text_input:
                processed = self.processor(text=text_input, return_tensors="pt", padding=True).to(device)
                with torch.no_grad():
                    return {"embeddings": self.model.get_text_features(**processed).tolist()}
            elif image is not None:
                processed = self.processor(images=image, return_tensors="pt").to(device)
                with torch.no_grad():
                    return {"embeddings": self.model.get_image_features(**processed).tolist()}
            else:
                return {"embeddings": None}
        except Exception:
            stack_info = traceback.format_exc()
            logger.error("error handling request:\n%s", stack_info)
            return {"Error": stack_info}
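

if __name__ == "__main__":
    # Minimal local smoke test (a sketch, not part of the endpoint contract).
    # It assumes a file named "example.jpg" sits next to this script and simply
    # mimics the JSON payloads an Inference Endpoints client would send.
    handler = EndpointHandler()

    with open("example.jpg", "rb") as f:
        encoded = base64.b64encode(f.read()).decode("utf-8")

    image_result = handler({"inputs": {"image": encoded}})
    text_result = handler({"inputs": {"text": "a photo of a cat"}})

    # CLIP ViT-L/14 projects both modalities into a shared 768-dim space.
    print(len(image_result["embeddings"][0]))
    print(len(text_result["embeddings"][0]))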