File size: 2,413 Bytes
6483910
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
from typing import Dict, List, Any
import io
import base64
from PIL import Image
import torch
import open_clip

# Select the compute device once at import time, preferring Apple-silicon MPS,
# then CUDA, then CPU. ``device`` is a plain string ("mps" / "cuda" / "cpu"),
# which ``Tensor.to`` and ``Module.to`` both accept.
# ``torch.backends.mps`` only exists in newer torch builds, so it is probed
# with getattr rather than accessed directly.
if getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")



class EndpointHandler():
    """Zero-shot image classifier backed by an OpenCLIP model.

    An instance is called with a request payload (see ``__call__``) and returns
    per-class probabilities together with L2-normalized CLIP embeddings.
    """

    def __init__(self, path='hf-hub:laion/CLIP-ViT-g-14-laion2B-s12B-b42K'):
        """Load the tokenizer, model and image preprocessing transform.

        Args:
            path: Model identifier passed to both ``open_clip.get_tokenizer``
                and ``open_clip.create_model_from_pretrained``. The default
                refers to LAION's ViT-g/14 checkpoint on the Hugging Face Hub.
        """
        self.tokenizer = open_clip.get_tokenizer(path)
        self.model, self.preprocess = open_clip.create_model_from_pretrained(path)
        # open_clip documents that models come back in train mode; switch to
        # eval so any dropout / stochastic-depth layers are disabled for
        # inference.
        self.model = self.model.to(device).eval()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Score one image against a set of candidate class labels.

        Args:
            data: Request payload. The following keys are consumed (removed
                from ``data`` via ``pop``):
                ``classes`` -- label text(s) handed to the tokenizer,
                    e.g. a list of strings.
                ``base64_image`` -- the image file's bytes, base64-encoded.

        Returns:
            A dict with:
                ``text_probs`` -- softmax probabilities, one per entry of
                    ``classes`` and in the same order.
                ``image_features`` -- the L2-normalized image embedding.
                ``text_features`` -- the L2-normalized embedding of the FIRST
                    class only (row 0 of the text-feature matrix).

        Raises:
            KeyError: if ``classes`` or ``base64_image`` is missing from ``data``.
        """
        classes = data.pop('classes')
        base64_image = data.pop('base64_image')
        image_bytes = base64.b64decode(base64_image)

        # Open the image in a context manager so PIL releases it once the
        # preprocessing transform has produced a tensor.
        with Image.open(io.BytesIO(image_bytes)) as pil_image:
            # unsqueeze(0) adds the batch dimension expected by encode_image.
            image = self.preprocess(pil_image).unsqueeze(0).to(device)
        text = self.tokenizer(classes).to(device)

        with torch.no_grad():
            image_features = self.model.encode_image(image)
            text_features = self.model.encode_text(text)
            # Normalize to unit length so the dot product is cosine similarity.
            image_features /= image_features.norm(dim=-1, keepdim=True)
            text_features /= text_features.norm(dim=-1, keepdim=True)

            # Fixed 100x temperature on the cosine similarities before softmax.
            text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)

        return {
            "text_probs": text_probs.tolist()[0],
            "image_features": image_features.tolist()[0],
            # NOTE(review): only the first class's embedding is returned;
            # confirm callers don't expect one embedding per class.
            "text_features": text_features.tolist()[0],
        }
    



if __name__ == "__main__":
    import sys

    # Image to classify: first command-line argument, otherwise the original
    # hard-coded sample file.
    image_path = (
        sys.argv[1]
        if len(sys.argv) > 1
        else "/Users/mpa/Library/Mobile Documents/com~apple~CloudDocs/mac/work/zillow-scrapper/properties/76031221/1af0f3c34bff2173ab74ae46a5905d4a-cc_ft_1536.jpg"
    )

    # Read and base64-encode the image before the (slow) model load, so a bad
    # path fails immediately. The handler expects base64 text, as an HTTP
    # payload would carry it.
    with open(image_path, "rb") as f:
        base64_image = base64.b64encode(f.read()).decode("utf-8")

    handler = EndpointHandler()

    # Keep a separate reference: the handler pops "classes" out of ``data``.
    classes = ["bedroom", "kitchen", "bathroom", "living room", "dining room", "patio", "backyard", "front yard", "garage", "pool"]
    data = {
        "classes": classes,
        "base64_image": base64_image
    }
    results = handler(data)

    print('output')
    ranked = sorted(zip(classes, results["text_probs"]), key=lambda pair: pair[1], reverse=True)
    for label, prob in ranked:
        print(f"{label:>12}: {prob:.4f}")