from typing import Dict, List, Any
from transformers import pipeline
from PIL import Image
import requests
import os
from io import BytesIO
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
from diffusers import DiffusionPipeline
import torch
from torch import autocast
import base64

# SECURITY NOTE(review): a Hugging Face access token was committed to this file.
# Prefer supplying it through the HF_TOKEN environment variable; the literal is
# kept only as a fallback so existing deployments keep working. The committed
# token should be revoked and this fallback removed.
auth_token = os.environ.get("HF_TOKEN") or "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"


class EndpointHandler():
    """Inference handler that runs a text-guided inpainting pipeline on an image.

    The handler loads a CLIPSeg processor/model pair and a diffusers
    ``text_inpainting`` custom pipeline once at construction time, then serves
    requests through ``__call__``.
    """

    def __init__(self, path=""):
        """Load the segmentation model and the inpainting pipeline.

        Args:
            path: Accepted for interface compatibility; not used by this
                implementation (all models are loaded from relative paths).
        """
        self.processor = CLIPSegProcessor.from_pretrained("./clipseg-rd64-refined")
        self.model = CLIPSegForImageSegmentation.from_pretrained("./clipseg-rd64-refined")

        self.pipe = DiffusionPipeline.from_pretrained(
            "./",
            custom_pipeline="text_inpainting",
            segmentation_model=self.model,
            segmentation_processor=self.processor,
            revision="fp16",
            torch_dtype=torch.float16,
            use_auth_token=auth_token,
        )

        # Run on the GPU when one is available, otherwise fall back to the CPU.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.pipe = self.pipe.to(self.device)

    def pad_image(self, image):
        """Return ``image`` centred on a black square canvas.

        A square image is returned unchanged (the same object). Otherwise a
        new image whose side equals the longer edge is created in the same
        mode and the original is pasted in the middle of the shorter axis.

        Args:
            image: A PIL image.

        Returns:
            A square PIL image.
        """
        w, h = image.size
        if w == h:
            return image

        side = max(w, h)
        canvas = Image.new(image.mode, (side, side), (0, 0, 0))
        # Only the shorter axis gets a non-zero offset, which centres the image.
        canvas.paste(image, ((side - w) // 2, (side - h) // 2))
        return canvas

    def process_image(self, image, text, prompt):
        """Pad, resize and inpaint ``image``.

        Args:
            image: Input PIL image in any mode.
            text: Forwarded to the pipeline as ``text`` (presumably the phrase
                describing the region to segment -- confirm against the
                ``text_inpainting`` pipeline).
            prompt: Forwarded to the pipeline as ``prompt``.

        Returns:
            The first image produced by the pipeline.
        """
        # Normalise to RGB first: the padding canvas is filled with an RGB
        # tuple, which single-band modes (e.g. grayscale) do not accept.
        image = image.convert("RGB")
        image = self.pad_image(image)
        image = image.resize((512, 512))

        with autocast(self.device):
            inpainted_image = self.pipe(image=image, text=text, prompt=prompt).images[0]

        return inpainted_image

    def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
        """Handle one inference request.

        Args:
            data: Request payload. If it contains an ``"inputs"`` key, that
                value is used (and removed from ``data``); otherwise ``data``
                itself is used. The resulting mapping must provide:

                * ``"image"``: base64-encoded image bytes.
                * ``"class_text"``: passed to the pipeline as ``text``.
                * ``"prompt"``: passed to the pipeline as ``prompt``.

        Returns:
            A dict ``{"image": <base64-encoded JPEG string>}``.

        Raises:
            KeyError: If one of the required keys is missing.
        """
        # get inputs
        inputs = data.pop("inputs", data)

        # decode base64 image to PIL
        image = Image.open(BytesIO(base64.b64decode(inputs['image'])))
        class_text = inputs['class_text']
        prompt = inputs['prompt']

        # run inference pipeline (process_image enters autocast itself)
        image = self.process_image(image, class_text, prompt)

        # encode image as base 64
        buffered = BytesIO()
        image.save(buffered, format="JPEG")
        img_str = base64.b64encode(buffered.getvalue())

        # postprocess the prediction
        return {"image": img_str.decode()}