from io import BytesIO
import base64
import traceback
import logging

from PIL import Image
import torch
from transformers import CLIPProcessor, CLIPModel

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

class EndpointHandler:
    def __init__(self, path=""):
        # Load the model and processor once, at endpoint startup.
        self.model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").to(device)
        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")

    def __call__(self, data):
        try:
            inputs = data.pop("inputs", None)
            text_input = None
            image = None  # initialize so the branches below never see an unbound name
            if isinstance(inputs, Image.Image):
                # The image arrived directly as a decoded PIL image.
                logger.info("image sent directly")
                image = inputs
            elif isinstance(inputs, dict):
                text_input = inputs.get("text")
                image_data = inputs.get("image")
                if image_data is not None:
                    # The image arrived as a base64-encoded string.
                    logger.info("image is encoded")
                    image = Image.open(BytesIO(base64.b64decode(image_data)))
            if text_input:
                encoded = self.processor(text=text_input, return_tensors="pt", padding=True).to(device)
                with torch.no_grad():
                    return {"embeddings": self.model.get_text_features(**encoded).tolist()}
            elif image is not None:
                encoded = self.processor(images=image, return_tensors="pt").to(device)
                with torch.no_grad():
                    return {"embeddings": self.model.get_image_features(**encoded).tolist()}
            else:
                return {"embeddings": None}
        except Exception as ex:
            # logger.exception logs the message together with the full traceback,
            # which is also returned to the caller for debugging.
            stack_info = traceback.format_exc()
            logger.exception("error handling request: %s", ex)
            return {"Error": stack_info}
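

if __name__ == "__main__":
    # Usage sketch: a minimal local smoke test for the handler above, not part
    # of the deployed endpoint. It assumes a sample image exists at "cat.jpg";
    # that path and the example prompt are placeholders for illustration only.
    handler = EndpointHandler()

    # Text embedding: text goes under the "inputs" key, as a string or a list.
    result = handler({"inputs": {"text": ["a photo of a cat"]}})
    print(len(result["embeddings"][0]))  # 768-dim CLIP ViT-L/14 text embedding

    # Image embedding: send the raw image bytes as a base64-encoded string.
    with open("cat.jpg", "rb") as f:
        b64 = base64.b64encode(f.read()).decode("utf-8")
    result = handler({"inputs": {"image": b64}})
    print(len(result["embeddings"][0]))  # 768-dim CLIP ViT-L/14 image embedding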