Abhilashvj committed
Commit 93910bd
1 Parent(s): 636182d

Added custom handler

Files changed (2)
  1. handler.py +81 -81
  2. test_handler.py +48 -48
handler.py CHANGED
from typing import Dict, Any
import base64
from io import BytesIO

import torch
from torch import autocast
from PIL import Image
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
from diffusers import DiffusionPipeline


# NOTE: a hard-coded token is committed here; it should be revoked and loaded from an
# environment variable instead.
auth_token = "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"


class EndpointHandler:
    def __init__(self, path=""):
        # CLIPSeg segments the region named by `class_text`; its mask drives the inpainting pipeline.
        self.processor = CLIPSegProcessor.from_pretrained("./clipseg-rd64-refined")
        self.model = CLIPSegForImageSegmentation.from_pretrained("./clipseg-rd64-refined")

        self.pipe = DiffusionPipeline.from_pretrained(
            "./",
            custom_pipeline="text_inpainting",
            segmentation_model=self.model,
            segmentation_processor=self.processor,
            revision="fp16",
            torch_dtype=torch.float16,
            use_auth_token=auth_token,
        )

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.pipe = self.pipe.to(self.device)

    def pad_image(self, image):
        """Pad the shorter side with black so the image becomes square."""
        w, h = image.size
        if w == h:
            return image
        elif w > h:
            new_image = Image.new(image.mode, (w, w), (0, 0, 0))
            new_image.paste(image, (0, (w - h) // 2))
            return new_image
        else:
            new_image = Image.new(image.mode, (h, h), (0, 0, 0))
            new_image.paste(image, ((h - w) // 2, 0))
            return new_image

    def process_image(self, image, text, prompt):
        """Pad and resize the image, then inpaint the region described by `text` using `prompt`."""
        image = self.pad_image(image)
        image = image.resize((512, 512))
        with autocast(self.device):
            inpainted_image = self.pipe(image=image, text=text, prompt=prompt).images[0]
        return inpainted_image

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        data args:
            inputs (:obj:`dict`) with keys:
                image (:obj:`str`): base64-encoded input image
                class_text (:obj:`str`): object to segment and replace
                prompt (:obj:`str`): text describing what to paint in its place
        Return:
            A :obj:`dict` with a base64-encoded result image; it will be serialized and returned.
        """
        # get inputs
        inputs = data.pop("inputs", data)

        # decode base64 image to PIL
        image = Image.open(BytesIO(base64.b64decode(inputs["image"])))
        class_text = inputs["class_text"]
        prompt = inputs["prompt"]

        # run inference pipeline
        with autocast(self.device):
            image = self.process_image(image, class_text, prompt)

        # encode the result image as base64
        buffered = BytesIO()
        image.save(buffered, format="JPEG")
        img_str = base64.b64encode(buffered.getvalue())

        # postprocess the prediction
        return {"image": img_str.decode()}
test_handler.py CHANGED
from handler import EndpointHandler
from typing import List
import base64
from io import BytesIO

import requests as r
from PIL import Image

# Fill these in to exercise a deployed Inference Endpoint instead of the local handler.
ENDPOINT_URL = ""
HF_TOKEN = ""


def decode_base64_image(image_string):
    """Decode a base64 string back into a PIL image."""
    base64_image = base64.b64decode(image_string)
    buffer = BytesIO(base64_image)
    return Image.open(buffer)


# init handler
my_handler = EndpointHandler(path=".")

# prepare sample payload
path_to_image = "test_images/lal.jpg"
with open(path_to_image, "rb") as i:
    b64 = base64.b64encode(i.read())

payload = {"inputs": {"image": b64.decode("utf-8"), "class_text": "shirt", "prompt": "wedding shirt"}}

# test the handler
results = my_handler(payload)

# save the returned image
decode_base64_image(results["image"]).save("test_results.jpg")


# Remote variant (left commented out): send a payload to the deployed endpoint.
# def predict(path_to_image: str = None, candidates: List[str] = None):
#     with open(path_to_image, "rb") as i:
#         b64 = base64.b64encode(i.read())
#
#     payload = {"inputs": {"image": b64.decode("utf-8"), "candidates": candidates}}
#     response = r.post(
#         ENDPOINT_URL, headers={"Authorization": f"Bearer {HF_TOKEN}"}, json=payload
#     )
#     return response.json()
#
#
# prediction = predict(
#     path_to_image="palace.jpg", candidates=["sea", "palace", "car", "ship"]
# )
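
The commented-out predict helper above still posts a candidates payload carried over from a zero-shot classification template, which this handler does not read. A request against a deployed Inference Endpoint running this handler would instead carry the image / class_text / prompt fields that EndpointHandler.__call__ expects. The following is a minimal sketch, not part of the commit: ENDPOINT_URL, HF_TOKEN, the image path, and the output filename are placeholders to fill in.

import base64
import requests

ENDPOINT_URL = ""  # URL of the deployed Inference Endpoint (placeholder)
HF_TOKEN = ""      # Hugging Face access token (placeholder)

# Build the same payload shape the local test uses: a base64 image plus the two text fields.
with open("test_images/lal.jpg", "rb") as f:
    b64_image = base64.b64encode(f.read()).decode("utf-8")

payload = {"inputs": {"image": b64_image, "class_text": "shirt", "prompt": "wedding shirt"}}

response = requests.post(
    ENDPOINT_URL,
    headers={"Authorization": f"Bearer {HF_TOKEN}"},
    json=payload,
)

# The handler returns {"image": "<base64-encoded JPEG>"}; decode and save the result.
result = response.json()
with open("remote_test_result.jpg", "wb") as f:
    f.write(base64.b64decode(result["image"]))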