Abhilashvj's picture
Added custom handler
93910bd
raw
history blame
No virus
2.74 kB
from typing import Dict, List, Any
from transformers import pipeline
from PIL import Image
import requests
import os
from io import BytesIO
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
from diffusers import DiffusionPipeline
import torch
from torch import autocast
import base64
# Hugging Face access token, read from the environment instead of being
# committed to source control (a token that has been committed must be
# treated as leaked and revoked). Falls back to None when neither variable
# is set, in which case no explicit token is passed to `from_pretrained`.
auth_token = (
    os.environ.get("HF_TOKEN")
    or os.environ.get("HUGGING_FACE_HUB_TOKEN")
    or None
)
class EndpointHandler():
    """Endpoint handler that inpaints a text-selected region of an image.

    A request supplies a base64-encoded image, a text describing the region
    to replace (``class_text``) and a ``prompt`` describing what to paint
    there. The response carries the inpainted image as a base64 JPEG.
    """

    def __init__(self, path=""):
        """Load the CLIPSeg model/processor and the ``text_inpainting`` pipeline.

        Args:
            path: Accepted for interface compatibility but not used; every
                model is loaded from a path relative to the current working
                directory.
        """
        # NOTE(review): `path` is ignored and "./" is resolved against the
        # process working directory -- confirm this matches the deployment.
        self.processor = CLIPSegProcessor.from_pretrained("./clipseg-rd64-refined")
        self.model = CLIPSegForImageSegmentation.from_pretrained("./clipseg-rd64-refined")
        self.pipe = DiffusionPipeline.from_pretrained(
            "./",
            custom_pipeline="text_inpainting",
            segmentation_model=self.model,
            segmentation_processor=self.processor,
            revision="fp16",
            torch_dtype=torch.float16,
            use_auth_token=auth_token,
        )
        # NOTE(review): weights are requested as float16 even when the CPU
        # fallback is selected -- verify half precision works on CPU here.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.pipe = self.pipe.to(self.device)

    @staticmethod
    def _square_layout(width, height):
        """Return ``(side, (x, y))``: square canvas size and centring offset."""
        side = max(width, height)
        return side, ((side - width) // 2, (side - height) // 2)

    def pad_image(self, image):
        """Pad ``image`` with black bars so that it becomes square.

        Args:
            image: A PIL image.

        Returns:
            The same object when the image is already square; otherwise a new
            image of the same mode whose side equals the longer edge, with
            the original pasted centred along the shorter axis.
        """
        w, h = image.size
        if w == h:
            return image
        side, offset = self._square_layout(w, h)
        canvas = Image.new(image.mode, (side, side), (0, 0, 0))
        canvas.paste(image, offset)
        return canvas

    def process_image(self, image, text, prompt):
        """Square, resize to 512x512 and inpaint ``image``.

        Args:
            image: A PIL image in any mode; it is converted to RGB first.
            text: Forwarded to the pipeline as ``text``.
            prompt: Forwarded to the pipeline as ``prompt``.

        Returns:
            The first image produced by the pipeline.
        """
        # pad_image fills with an RGB triple, which Pillow rejects for
        # single-band modes such as "L"; normalise the mode up front.
        if image.mode != "RGB":
            image = image.convert("RGB")
        image = self.pad_image(image)
        image = image.resize((512, 512))
        with autocast(self.device):
            inpainted_image = self.pipe(image=image, text=text, prompt=prompt).images[0]
        return inpainted_image

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle one inference request.

        Args:
            data: Request payload. The fields are read from ``data["inputs"]``
                when that key exists (it is popped from ``data``), otherwise
                from ``data`` itself:

                * ``image``: base64-encoded image bytes.
                * ``class_text``: text selecting the region to inpaint.
                * ``prompt``: text describing the replacement content.

        Returns:
            ``{"image": <base64-encoded JPEG as str>}``.

        Raises:
            KeyError: If one of the required fields is missing.
        """
        # get inputs
        inputs = data.pop("inputs", data)
        # decode base64 image to PIL
        image = Image.open(BytesIO(base64.b64decode(inputs['image'])))
        class_text = inputs['class_text']
        prompt = inputs['prompt']
        # run inference pipeline (process_image opens the autocast context
        # itself, so it is not wrapped a second time here)
        image = self.process_image(image, class_text, prompt)
        # JPEG cannot store alpha or palette data
        if image.mode != "RGB":
            image = image.convert("RGB")
        # encode image as base 64
        buffered = BytesIO()
        image.save(buffered, format="JPEG")
        img_str = base64.b64encode(buffered.getvalue())
        # postprocess the prediction
        return {"image": img_str.decode()}