abhi001vj committed
Commit 636182d
1 Parent(s): 1805a7b

added custom handler and test script

__pycache__/handler.cpython-310.pyc ADDED
Binary file (2.7 kB)
 
__pycache__/handler.cpython-39.pyc ADDED
Binary file (2.68 kB)
 
handler.py ADDED
@@ -0,0 +1,81 @@
+ from typing import Dict, List, Any
+ from transformers import pipeline
+ from PIL import Image
+ import requests
+ import os
+ from io import BytesIO
+ from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
+ from diffusers import DiffusionPipeline
+ import torch
+ from torch import autocast
+ import base64
+
+
+ # Hugging Face auth token, hardcoded and passed to the diffusion pipeline below
+ auth_token = "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"
+
+
+ class EndpointHandler():
+     def __init__(self, path=""):
+         self.processor = CLIPSegProcessor.from_pretrained("./clipseg-rd64-refined")
+         self.model = CLIPSegForImageSegmentation.from_pretrained("./clipseg-rd64-refined")
+
+         self.pipe = DiffusionPipeline.from_pretrained(
+             "./",
+             custom_pipeline="text_inpainting",
+             segmentation_model=self.model,
+             segmentation_processor=self.processor,
+             revision="fp16",
+             torch_dtype=torch.float16,
+             use_auth_token=auth_token,
+         )
+
+         self.device = "cuda" if torch.cuda.is_available() else "cpu"
+         self.pipe = self.pipe.to(self.device)
+
+     def pad_image(self, image):
+         # pad the shorter side with black so the image becomes square
+         w, h = image.size
+         if w == h:
+             return image
+         elif w > h:
+             new_image = Image.new(image.mode, (w, w), (0, 0, 0))
+             new_image.paste(image, (0, (w - h) // 2))
+             return new_image
+         else:
+             new_image = Image.new(image.mode, (h, h), (0, 0, 0))
+             new_image.paste(image, ((h - w) // 2, 0))
+             return new_image
+
+     def process_image(self, image, text, prompt):
+         # square-pad, resize to the pipeline's 512x512 input, then run text-guided inpainting
+         image = self.pad_image(image)
+         image = image.resize((512, 512))
+         with autocast(self.device):
+             inpainted_image = self.pipe(image=image, text=text, prompt=prompt).images[0]
+         return inpainted_image
+
+     def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
+         """
+         data args:
+             inputs (:obj:`dict`) with keys:
+                 image (:obj:`str`): base64-encoded input image
+                 class_text (:obj:`str`): object to segment and replace
+                 prompt (:obj:`str`): text prompt for the inpainted region
+         Return:
+             A :obj:`dict` with the base64-encoded result image; it will be serialized and returned
+         """
+         # get inputs
+         inputs = data.pop("inputs", data)
+
+         # decode base64 image to PIL
+         image = Image.open(BytesIO(base64.b64decode(inputs['image'])))
+         class_text = inputs['class_text']
+         prompt = inputs['prompt']
+         # run inference pipeline
+         with autocast(self.device):
+             image = self.process_image(image, class_text, prompt)
+
+         # encode image as base64
+         buffered = BytesIO()
+         image.save(buffered, format="JPEG")
+         img_str = base64.b64encode(buffered.getvalue())
+
+         # postprocess the prediction
+         return {"image": img_str.decode()}
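For reference, the `pad_image` helper in `handler.py` letterboxes a non-square input onto a black square canvas before the 512x512 resize. A minimal sketch of that geometry, using an arbitrary 640x480 example image (the size is illustrative, not part of this commit):

from PIL import Image

# arbitrary 640x480 test image (example only)
img = Image.new("RGB", (640, 480), (255, 255, 255))

# width > height, so the handler pads to a 640x640 black square and pastes
# the original at (0, (640 - 480) // 2) == (0, 80) before resizing to 512x512
w, h = img.size
padded = Image.new(img.mode, (w, w), (0, 0, 0))
padded.paste(img, (0, (w - h) // 2))
print(padded.size)  # (640, 640)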
test_handler.py ADDED
@@ -0,0 +1,48 @@
+ from handler import EndpointHandler
+ import json
+ from typing import List
+ import requests as r
+ import base64
+ from PIL import Image
+ from io import BytesIO
+
+ ENDPOINT_URL = ""
+ HF_TOKEN = ""
+
+ def decode_base64_image(image_string):
+     base64_image = base64.b64decode(image_string)
+     buffer = BytesIO(base64_image)
+     return Image.open(buffer)
+
+ # init handler
+ my_handler = EndpointHandler(path=".")
+
+ # prepare sample payload
+ path_to_image = "test_images/lal.jpg"
+ with open(path_to_image, "rb") as i:
+     b64 = base64.b64encode(i.read())
+
+ payload = {"inputs": {"image": b64.decode("utf-8"), "class_text": "shirt", "prompt": "wedding shirt"}}
+
+ # test the handler
+ results = my_handler(payload)
+
+ # show results: decode the returned base64 image and save it
+ decode_base64_image(results["image"]).save("test_results.jpg")
+
+ # def predict(path_to_image: str = None, candidates: List[str] = None):
+ #     with open(path_to_image, "rb") as i:
+ #         b64 = base64.b64encode(i.read())
+
+ #     payload = {"inputs": {"image": b64.decode("utf-8"), "candidates": candidates}}
+ #     response = r.post(
+ #         ENDPOINT_URL, headers={"Authorization": f"Bearer {HF_TOKEN}"}, json=payload
+ #     )
+ #     return response.json()
+
+
+ # prediction = predict(
+ #     path_to_image="palace.jpg", candidates=["sea", "palace", "car", "ship"]
+ # )
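Once the handler is deployed as an Inference Endpoint, the same payload can be sent over HTTP instead of instantiating `EndpointHandler` locally. The commented-out `predict` helper above hints at this; here is a minimal sketch adapted to this handler's `image`/`class_text`/`prompt` payload (the endpoint URL and token are placeholders, not values from this commit):

import base64
import requests

ENDPOINT_URL = "https://<your-endpoint>.endpoints.huggingface.cloud"  # placeholder
HF_TOKEN = "<hf_token>"  # placeholder

# encode the test image the same way test_handler.py does
with open("test_images/lal.jpg", "rb") as f:
    b64_image = base64.b64encode(f.read()).decode("utf-8")

payload = {"inputs": {"image": b64_image, "class_text": "shirt", "prompt": "wedding shirt"}}

response = requests.post(
    ENDPOINT_URL,
    headers={"Authorization": f"Bearer {HF_TOKEN}"},
    json=payload,
)

# the handler returns {"image": "<base64 JPEG>"}; decode and save the result
result = response.json()
with open("endpoint_result.jpg", "wb") as f:
    f.write(base64.b64decode(result["image"]))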
test_images/lal.jpg ADDED