MPA commited on
Commit
6483910
1 Parent(s): c39215b

Upload 2 files

Browse files
Files changed (2) hide show
  1. handler.py +70 -0
  2. requirements.txt +2 -0
handler.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
2
+ import io
3
+ import base64
4
+ from PIL import Image
5
+ import torch
6
+ import open_clip
7
+
8
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
9
+
10
+
11
+ if torch.backends.mps.is_available():
12
+ device = "mps"
13
+ else:
14
+ device = "cuda" if torch.cuda.is_available() else "cpu"
15
+ print(f"Using device: {device}")
16
+
17
+
18
+
19
+ class EndpointHandler():
20
+ def __init__(self, path='hf-hub:laion/CLIP-ViT-g-14-laion2B-s12B-b42K'):
21
+ self.tokenizer = open_clip.get_tokenizer(path)
22
+ self.model, self.preprocess = open_clip.create_model_from_pretrained(path)
23
+ self.model = self.model.to(device)
24
+
25
+
26
+ def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
27
+ """
28
+ data args:
29
+ inputs (:obj: `str`)
30
+ date (:obj: `str`)
31
+ Return:
32
+ A :obj:`list` | `dict`: will be serialized and returned
33
+ """
34
+ # get inputs
35
+ classes = data.pop('classes')
36
+ base64_image = data.pop('base64_image')
37
+ image_data = base64.b64decode(base64_image)
38
+ image = Image.open(io.BytesIO(image_data))
39
+ image = self.preprocess(image).unsqueeze(0).to(device)
40
+ text = self.tokenizer(classes).to(device)
41
+
42
+ with torch.no_grad():
43
+ image_features = self.model.encode_image(image)
44
+ text_features = self.model.encode_text(text)
45
+ image_features /= image_features.norm(dim=-1, keepdim=True)
46
+ text_features /= text_features.norm(dim=-1, keepdim=True)
47
+
48
+ text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
49
+ return {
50
+ "text_probs": text_probs.tolist()[0],
51
+ "image_features" : image_features.tolist()[0],
52
+ "text_features" : text_features.tolist()[0]
53
+ }
54
+
55
+
56
+
57
+
58
+ if __name__ == "__main__":
59
+ handler = EndpointHandler()
60
+ # read image from disk and decode to base 64
61
+ with open("/Users/mpa/Library/Mobile Documents/com~apple~CloudDocs/mac/work/zillow-scrapper/properties/76031221/1af0f3c34bff2173ab74ae46a5905d4a-cc_ft_1536.jpg", "rb") as f:
62
+ image_data = f.read()
63
+ base64_image = base64.b64encode(image_data).decode("utf-8")
64
+
65
+ data = {
66
+ "classes": ["bedroom", "kitchen", "bathroom", "living room", "dining room", "patio", "backyard", "front yard", "garage", "pool"],
67
+ "base64_image": base64_image
68
+ }
69
+ results = handler(data)
70
+ print('output')
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ pillow==10.3.0
2
+ open-clip-torch==2.24.0