|
from typing import Dict, List, Any |
|
from transformers import Pipeline |
|
from PIL import Image |
|
from io import BytesIO |
|
import base64 |
|
import json |
|
from visual_chatgpt import ImageEditing, Text2Box, Segmenting, Inpainting |
|
|
|
class EndpointHandler:
    """Inference-endpoint handler that removes or replaces an object in an image.

    Wires the ``visual_chatgpt`` components (text-to-box grounding,
    segmentation and inpainting) into an ``ImageEditing`` model and dispatches
    each request to either object removal or object replacement.
    """

    def __init__(self, path="", device="cuda"):
        """Load all sub-models on ``device``.

        Args:
            path: Model directory supplied by the serving runtime. Not used
                by this handler; kept for interface compatibility.
            device: Device string passed to every sub-model constructor.
                Defaults to ``"cuda"``, the previously hard-coded value.
        """
        self.sam = Segmenting(device)
        self.inpaint = Inpainting(device)
        self.grounding = Text2Box(device)
        # Passed positionally as (grounding, segmentation, inpainting),
        # exactly as before.
        self.model = ImageEditing(self.grounding, self.sam, self.inpaint)

    def __call__(self, data: Dict[str, Any]) -> Any:
        """Run one removal or replacement request.

        Args:
            data: Request payload of the form
                ``{"inputs": {"image": <base64 str>, "target": <str>,
                "replacement": <str, optional>}}``. ``image`` is the
                base64-encoded bytes of an image file PIL can open; ``target``
                describes the object to edit; ``replacement`` describes what
                to put in its place. A missing, empty or whitespace-only
                ``replacement`` means "remove the target". The payload is not
                modified.

        Returns:
            The result of ``self.model.inference_remove`` or
            ``self.model.inference_replace_sam`` (presumably a PIL image —
            the demo below calls ``.save`` on it; confirm in visual_chatgpt).

        Raises:
            KeyError: if ``data`` has no ``"inputs"`` key, or ``inputs`` lacks
                ``"image"`` or ``"target"``.
            TypeError: if ``data["inputs"]`` is not a dict.
            Errors from base64 decoding or PIL image loading propagate as-is.
        """
        info = data["inputs"]
        if not isinstance(info, dict):
            raise TypeError(
                f"data['inputs'] must be a dict, got {type(info).__name__}"
            )

        # Fail with a clear message instead of silently substituting the
        # whole request dict for a missing field (the old .pop(key, data)).
        for key in ("image", "target"):
            if key not in info:
                raise KeyError(f"data['inputs'] is missing required field {key!r}")

        # Read without .pop so the caller's payload is left untouched.
        image_bytes = base64.b64decode(info["image"])
        raw_image = Image.open(BytesIO(image_bytes)).convert("RGB")

        target = info["target"]
        replacement = info.get("replacement") or ""

        if not str(replacement).strip():
            return self.model.inference_remove(raw_image, target)
        return self.model.inference_replace_sam(raw_image, target, replacement)
|
|
|
|
|
if __name__ == "__main__":
    # Local demo: base64-encode an image file, send a removal/replacement
    # request through the handler and save the result. With no arguments it
    # uses the same hard-coded values as before.
    import argparse

    parser = argparse.ArgumentParser(
        description="Run EndpointHandler on a local image file."
    )
    parser.add_argument(
        "image",
        nargs="?",
        default="/home/ubuntu/guoling/1.png",
        help="path to the input image",
    )
    parser.add_argument(
        "--target", default="the pig", help="text description of the object to edit"
    )
    parser.add_argument(
        "--replacement",
        default="",
        help="what to replace the target with; empty means remove it",
    )
    parser.add_argument(
        "--output", default="new1.png", help="where to save the edited image"
    )
    args = parser.parse_args()

    # Read the input before loading the (slow) models so a bad path fails fast.
    with open(args.image, "rb") as img:
        image_base64 = base64.b64encode(img.read()).decode("utf-8")

    my_handler = EndpointHandler(path=".")

    data = {
        "inputs": {
            "image": image_base64,
            "target": args.target,
            "replacement": args.replacement,
        }
    }

    result = my_handler(data)
    # Presumably a PIL image (it exposes .save) — confirm in visual_chatgpt.
    result.save(args.output)
|
|